chore: block all HTTP requests in CI (#5088)

ZanSara 2023-06-13 14:52:24 +02:00 committed by GitHub
parent 29a6bfe621
commit 65cdf36d72
10 changed files with 427 additions and 218 deletions
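
For context on what this changes for test authors: any test marked `unit` now fails fast if it tries to open a real HTTP connection (see the `request_blocker` fixture in `test/conftest.py` below). A minimal, hypothetical illustration, assuming that fixture is applied automatically (autouse):

import pytest
import requests


@pytest.mark.unit
def test_http_requests_are_blocked():
    # With HTTPConnectionPool.urlopen monkeypatched to raise, the GET below
    # never leaves the machine; an accidental network call in a unit test now
    # surfaces as a RuntimeError instead of a slow or flaky download.
    with pytest.raises(RuntimeError):
        requests.get("https://example.com")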

View File

@ -2,8 +2,12 @@ import logging
import os
import re
from typing import Tuple
from unittest.mock import patch
from unittest.mock import Mock, patch
from events import Events
from haystack.agents.types import AgentTokenStreamingHandler
from test.conftest import MockRetriever, MockPromptNode
from unittest import mock
import pytest
@ -290,17 +294,6 @@ def test_tool_processes_answer_result_and_document_result():
assert tool._process_result(Document(content="content")) == "content"
def test_invalid_agent_template():
pn = PromptNode()
with pytest.raises(ValueError, match="some_non_existing_template not supported"):
Agent(prompt_node=pn, prompt_template="some_non_existing_template")
# if prompt_template is None, then we'll use zero-shot-react
a = Agent(prompt_node=pn, prompt_template=None)
assert isinstance(a.prompt_template, PromptTemplate)
assert a.prompt_template.name == "zero-shot-react"
@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
@ -315,3 +308,25 @@ def test_default_template_order(mock_model, mock_prompt):
a = Agent(prompt_node=pn, prompt_template="translation")
assert a.prompt_template.name == "translation"
@pytest.mark.unit
def test_agent_with_unknown_prompt_template():
prompt_node = Mock()
prompt_node.get_prompt_template.return_value = None
with pytest.raises(ValueError, match="Prompt template 'invalid' not found"):
Agent(prompt_node=prompt_node, prompt_template="invalid")
@pytest.mark.unit
def test_agent_token_streaming_handler():
e = Events("on_new_token")
mock_callback = Mock()
e.on_new_token += mock_callback # register the mock callback to the event
handler = AgentTokenStreamingHandler(events=e)
result = handler("test")
assert result == "test"
mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test"

View File

@ -1,17 +1,15 @@
import pytest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock
from haystack.agents.conversational import ConversationalAgent
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
from haystack.nodes import PromptNode
from test.conftest import MockPromptNode
@pytest.mark.unit
def test_init():
with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
prompt_node = MockPromptNode()
agent = ConversationalAgent(prompt_node)
# Test normal case
assert isinstance(agent.memory, ConversationMemory)
@ -25,32 +23,27 @@ def test_init():
@pytest.mark.unit
def test_init_with_summary_memory():
with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode(default_prompt_template="this is a test")
# Test with summary memory
agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
assert isinstance(agent.memory, ConversationSummaryMemory)
# Test with summary memory
prompt_node = MockPromptNode()
agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
assert isinstance(agent.memory, ConversationSummaryMemory)
@pytest.mark.unit
def test_init_with_no_memory():
with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)
prompt_node = MockPromptNode()
# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)
@pytest.mark.unit
def test_run():
with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
prompt_node = PromptNode()
agent = ConversationalAgent(prompt_node)
prompt_node = MockPromptNode()
agent = ConversationalAgent(prompt_node)
# Mock the Agent run method
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")
# Run the agent with the mock prompt node
result = agent.run("query")
# empty answer
assert result["answers"][0].answer == ""

View File

@ -1,9 +1,10 @@
import unittest
from typing import Optional, Union, List, Dict, Any
from unittest import mock
import pytest
from haystack import Pipeline, Answer, Document
from haystack import Pipeline, Answer, Document, BaseComponent, MultiLabel
from haystack.agents.base import ToolsManager, Tool
@ -160,3 +161,57 @@ def test_extract_tool_name_and_empty_tool_input(tools_manager):
for example in examples:
tool_name, tool_input = tools_manager.extract_tool_name_and_tool_input(example)
assert tool_name == "Search" and tool_input == ""
@pytest.mark.unit
def test_node_as_tool():
# test that a component can be used as a tool
class ToolComponent(BaseComponent):
outgoing_edges = 1
def run_batch(
self,
queries: Optional[Union[str, List[str]]] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
):
pass
def run(self, **kwargs):
return "mocked_output"
tool = Tool(name="ToolA", pipeline_or_node=ToolComponent(), description="Tool A Description")
assert tool.run("input") == "mocked_output"
@pytest.mark.unit
def test_tools_manager_exception():
# tests exception raising in tools manager
class ToolComponent(BaseComponent):
outgoing_edges = 1
def run_batch(
self,
queries: Optional[Union[str, List[str]]] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
):
pass
def run(self, **kwargs):
raise Exception("mocked_exception")
fake_llm_response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy born"
tool = Tool(name="Search", pipeline_or_node=ToolComponent(), description="Search")
tools_manager = ToolsManager(tools=[tool])
with pytest.raises(Exception):
tools_manager.run_tool(llm_response=fake_llm_response)

View File

@ -8,6 +8,7 @@ from pathlib import Path
import os
import re
from functools import wraps
from unittest.mock import patch
import requests_cache
import responses
@ -847,5 +848,14 @@ def request_blocker(request: pytest.FixtureRequest, monkeypatch):
marker = request.node.get_closest_marker("unit")
if marker is None:
return
monkeypatch.delattr("requests.sessions.Session")
monkeypatch.delattr("requests_cache.session.CachedSession")
def urlopen_mock(self, method, url, *args, **kwargs):
raise RuntimeError(f"The test was about to {method} {self.scheme}://{self.host}{url}")
monkeypatch.setattr("urllib3.connectionpool.HTTPConnectionPool.urlopen", urlopen_mock)
@pytest.fixture
def mock_auto_tokenizer():
with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
yield mock_from_pretrained
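
A hypothetical usage sketch of the new `mock_auto_tokenizer` fixture (not part of this diff): because `AutoTokenizer.from_pretrained` is patched, unit tests receive a mock object instead of downloading tokenizer files from the Hugging Face Hub.

import pytest
from transformers import AutoTokenizer


@pytest.mark.unit
def test_tokenizer_is_mocked(mock_auto_tokenizer):
    # No files are fetched; the patched from_pretrained records the call and
    # returns a MagicMock that tests can configure as needed.
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    assert tokenizer is mock_auto_tokenizer.return_value
    mock_auto_tokenizer.assert_called_once()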

View File

@ -15,7 +15,7 @@ from ..conftest import fail_at_version
@pytest.mark.unit
@fail_at_version(1, 18)
def test_seq2seq_deprecation():
def test_seq2seq_deprecation(mock_auto_tokenizer):
with pytest.warns(DeprecationWarning):
try:
Seq2SeqGenerator("non_existing_model/model")

View File

@ -8,10 +8,21 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamin
from haystack.nodes.prompt.invocation_layer import AnthropicClaudeInvocationLayer
@pytest.fixture
def mock_claude_tokenizer():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer", autospec=True) as mock_tokenizer:
yield mock_tokenizer
@pytest.fixture
def mock_claude_request():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
yield mock_request
@pytest.mark.unit
def test_default_costuctor():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
def test_default_constructor(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
assert layer.api_key == "some_fake_key"
assert layer.max_length == 200
@ -20,7 +31,7 @@ def test_default_costuctor():
@pytest.mark.unit
def test_ignored_kwargs_are_filtered_in_init():
def test_ignored_kwargs_are_filtered_in_init(mock_claude_tokenizer, mock_claude_request):
kwargs = {
"temperature": 1,
"top_p": 5,
@ -30,8 +41,8 @@ def test_ignored_kwargs_are_filtered_in_init():
"stream_handler": DefaultTokenStreamingHandler(),
"unkwnown_args": "this will be filtered out",
}
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)
# Verify unexpected kwargs are filtered out
assert len(layer.model_input_kwargs) == 6
@ -45,9 +56,8 @@ def test_ignored_kwargs_are_filtered_in_init():
@pytest.mark.unit
def test_invoke_with_no_kwargs():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
def test_invoke_with_no_kwargs(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
with pytest.raises(ValueError) as e:
layer.invoke()
@ -55,14 +65,12 @@ def test_invoke_with_no_kwargs():
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry")
def test_invoke_with_prompt_only(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
def test_invoke_with_prompt_only(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
# Create a fake response
mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
mock_request.return_value = mock_response
mock_claude_request.return_value = mock_response
res = layer.invoke(prompt="Some prompt")
assert len(res) == 1
@ -70,14 +78,13 @@ def test_invoke_with_prompt_only(mock_request):
@pytest.mark.unit
def test_invoke_with_kwargs():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
def test_invoke_with_kwargs(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
# Create a fake response
mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
mock_request.return_value = mock_response
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
mock_invocation_request.return_value = mock_response
res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=["stop", "here"])
assert len(res) == 1
assert res[0] == "some_result"
@ -92,19 +99,18 @@ def test_invoke_with_kwargs():
"stream": False,
"stop_sequences": ["stop", "here", "\n\nHuman: "],
}
mock_request.assert_called_once()
assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
mock_invocation_request.assert_called_once()
assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)
@pytest.mark.unit
def test_invoke_with_none_stop_words():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
def test_invoke_with_none_stop_words(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
# Create a fake response
mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
mock_request.return_value = mock_response
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
mock_invocation_request.return_value = mock_response
res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=None)
assert len(res) == 1
assert res[0] == "some_result"
@ -119,14 +125,13 @@ def test_invoke_with_none_stop_words():
"stream": False,
"stop_sequences": ["\n\nHuman: "],
}
mock_request.assert_called_once()
assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
mock_invocation_request.assert_called_once()
assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)
@pytest.mark.unit
def test_invoke_with_stream():
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
def test_invoke_with_stream(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
# Create a fake streamed response
def mock_iter(self):
@ -141,8 +146,8 @@ def test_invoke_with_stream():
mock_response = Mock(**{"__iter__": mock_iter})
# Verifies expected result is returned
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
mock_request.return_value = mock_response
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
mock_invocation_request.return_value = mock_response
res = layer.invoke(prompt="Some prompt", stream=True)
assert len(res) == 1
@ -150,14 +155,13 @@ def test_invoke_with_stream():
@pytest.mark.unit
def test_invoke_with_custom_stream_handler():
def test_invoke_with_custom_stream_handler(mock_claude_tokenizer, mock_claude_request):
# Create a mock stream handler that always return the same token when called
mock_stream_handler = Mock()
mock_stream_handler.return_value = "token"
# Create a layer with a mocked stream handler
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)
# Create a fake streamed response
def mock_iter(self):
@ -171,8 +175,8 @@ def test_invoke_with_custom_stream_handler():
mock_response = Mock(**{"__iter__": mock_iter})
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
mock_request.return_value = mock_response
with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
mock_invocation_request.return_value = mock_response
res = layer.invoke(prompt="Some prompt")
assert len(res) == 1
@ -186,7 +190,7 @@ def test_invoke_with_custom_stream_handler():
@pytest.mark.unit
def test_ensure_token_limit_fails_if_called_with_list():
def test_ensure_token_limit_fails_if_called_with_list(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
with pytest.raises(ValueError):
layer._ensure_token_limit(prompt=[])
@ -225,7 +229,7 @@ def test_ensure_token_limit_with_huge_max_length(caplog):
@pytest.mark.unit
def test_supports():
def test_supports(mock_claude_tokenizer, mock_claude_request):
layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
assert not layer.supports("claude")

View File

@ -1,14 +1,14 @@
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import Mock
import pytest
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer
@pytest.mark.unit
def test_default_constructor():
def test_default_constructor(mock_auto_tokenizer):
"""
Test that the default constructor sets the correct values
"""
@ -28,7 +28,7 @@ def test_default_constructor():
@pytest.mark.unit
def test_constructor_with_model_kwargs():
def test_constructor_with_model_kwargs(mock_auto_tokenizer):
"""
Test that model_kwargs are correctly set in the constructor
and that model_kwargs_rejected are correctly filtered out
@ -43,7 +43,7 @@ def test_constructor_with_model_kwargs():
@pytest.mark.unit
def test_invoke_with_no_kwargs():
def test_invoke_with_no_kwargs(mock_auto_tokenizer):
"""
Test that invoke raises an error if no prompt is provided
"""
@ -54,7 +54,7 @@ def test_invoke_with_no_kwargs():
@pytest.mark.unit
def test_invoke_with_stop_words():
def test_invoke_with_stop_words(mock_auto_tokenizer):
"""
Test stop words are correctly passed from PromptNode to wire in CohereInvocationLayer
"""
@ -62,98 +62,135 @@ def test_invoke_with_stop_words():
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
mock_post.return_value = Mock(text='{"generations":[{"text": "Hello"}]}')
layer.invoke(prompt="Tell me hello", stop_words=stop_words)
assert mock_post.called
# Check if stop_words are passed to _post as stop parameter
called_args, _ = mock_post.call_args
assert "end_sequences" in called_args[0]
assert called_args[0]["end_sequences"] == stop_words
assert mock_post.called
called_args, _ = mock_post.call_args
assert "end_sequences" in called_args[0]
assert called_args[0]["end_sequences"] == stop_words
@pytest.mark.unit
@pytest.mark.parametrize("using_constructor", [True, False])
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param(using_constructor, stream):
def test_streaming_stream_param_from_init(mock_auto_tokenizer):
"""
Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer
Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from init
"""
if using_constructor:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
else:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
# Mock the response
mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
layer.invoke(prompt="Tell me hello")
if using_constructor:
layer.invoke(prompt="Tell me hello")
else:
layer.invoke(prompt="Tell me hello", stream=stream)
assert mock_post.called
_, called_kwargs = mock_post.call_args
assert mock_post.called
# Check if stop_words are passed to _post as stop parameter
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# Check if stream is True, then stream is passed as True to _post
if stream:
assert called_kwargs["stream"]
# Check if stream is False, then stream is passed as False to _post
else:
assert not called_kwargs["stream"]
# stream is always passed to _post
assert "stream" in called_kwargs and called_kwargs["stream"]
@pytest.mark.unit
@pytest.mark.parametrize("using_constructor", [True, False])
@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
def test_streaming_stream_handler_param(using_constructor, stream_handler):
def test_streaming_stream_param_from_init_no_stream(mock_auto_tokenizer):
"""
Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from init
"""
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response
mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
layer.invoke(prompt="Tell me hello")
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert not bool(called_kwargs["stream"])
@pytest.mark.unit
def test_streaming_stream_param_from_invoke(mock_auto_tokenizer):
"""
Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from invoke
"""
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response
mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
layer.invoke(prompt="Tell me hello", stream=True)
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert bool(called_kwargs["stream"])
@pytest.mark.unit
def test_streaming_stream_param_from_invoke_no_stream(mock_auto_tokenizer):
"""
Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from invoke
"""
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response
mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
layer.invoke(prompt="Tell me hello", stream=False)
assert mock_post.called
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
assert not bool(called_kwargs["stream"])
@pytest.mark.unit
def test_streaming_stream_handler_param_from_init(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
"""
if using_constructor:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
else:
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
stream_handler = DefaultTokenStreamingHandler()
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
with unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
) as mock_post, unittest.mock.patch(
"haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
) as mock_post_stream:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response
mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
responses = layer.invoke(prompt="Tell me hello")
if using_constructor:
layer.invoke(prompt="Tell me hello")
else:
layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
assert mock_post.called
_, called_kwargs = mock_post.call_args
assert "stream" in called_kwargs
assert bool(called_kwargs["stream"])
assert responses == ["Hello there"]
assert mock_post.called
# Check if stop_words are passed to _post as stop parameter
called_args, called_kwargs = mock_post.call_args
@pytest.mark.unit
def test_streaming_stream_handler_param_from_invoke(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
"""
stream_handler = DefaultTokenStreamingHandler()
layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
# stream is always passed to _post
assert "stream" in called_kwargs
with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
# Mock the response
mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
responses = layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
# if stream_handler is used then stream is always True
if stream_handler:
assert called_kwargs["stream"]
# and stream_handler is passed as an instance of TokenStreamingHandler
called_args, called_kwargs = mock_post_stream.call_args
assert "stream_handler" in called_kwargs
assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
# if stream_handler is not used then stream is always False
else:
assert not called_kwargs["stream"]
assert mock_post.called
_, called_kwargs = mock_post.call_args
assert "stream" in called_kwargs
assert bool(called_kwargs["stream"])
assert responses == ["Hello there"]
@pytest.mark.unit
@ -181,13 +218,13 @@ def test_supports():
@pytest.mark.unit
def test_ensure_token_limit_fails_if_called_with_list():
def test_ensure_token_limit_fails_if_called_with_list(mock_auto_tokenizer):
layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
with pytest.raises(ValueError):
layer._ensure_token_limit(prompt=[])
@pytest.mark.unit
@pytest.mark.integration
def test_ensure_token_limit_with_small_max_length(caplog):
layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
res = layer._ensure_token_limit(prompt="Short prompt")
@ -200,7 +237,7 @@ def test_ensure_token_limit_with_small_max_length(caplog):
assert not caplog.records
@pytest.mark.unit
@pytest.mark.integration
def test_ensure_token_limit_with_huge_max_length(caplog):
layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
res = layer._ensure_token_limit(prompt="Short prompt")

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, Mock
import pytest
import torch
from torch import device
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
from transformers import AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler, DefaultTokenStreamingHandler
@ -13,8 +13,11 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import StopWordsCriteri
@pytest.fixture
def mock_pipeline():
# mock transformers pipeline
# model returning some mocked text for pipeline invocation
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mocked_pipeline:
mocked_pipeline.return_value = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
pipeline_mock = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
pipeline_mock.side_effect = lambda *args, **kwargs: [{"generated_text": "some mocked text"}]
mocked_pipeline.return_value = pipeline_mock
yield mocked_pipeline
@ -44,7 +47,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
mock_pipeline.assert_called_once()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
# device is set to cpu by default and device_map is empty
assert kwargs["device"] == device("cpu")
@ -87,7 +90,7 @@ def test_constructor_with_model_name_and_device_map(mock_pipeline, mock_get_task
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
# device is NOT set; device_map is auto because device_map takes precedence over device
assert not kwargs["device"]
@ -110,7 +113,7 @@ def test_constructor_with_torch_dtype(mock_pipeline, mock_get_task):
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
assert kwargs["torch_dtype"] == torch.float16
@ -126,7 +129,7 @@ def test_constructor_with_torch_dtype_as_str(mock_pipeline, mock_get_task):
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
assert kwargs["torch_dtype"] == torch.float16
@ -142,7 +145,7 @@ def test_constructor_with_torch_dtype_auto(mock_pipeline, mock_get_task):
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
assert kwargs["torch_dtype"] == "auto"
@ -206,9 +209,8 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
"""
Test that the constructor sets the pipeline with the pretrained model (if provided)
"""
# actual model and tokenizer passed to the pipeline
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = Mock()
tokenizer = Mock()
HFLocalInvocationLayer(
model_name_or_path="irrelevant_when_model_is_provided",
@ -221,7 +223,7 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
# mock_get_task is not called as we provided task_name parameter
mock_get_task.assert_not_called()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
# correct tokenizer and model are set as well
assert kwargs["tokenizer"] == tokenizer
@ -239,7 +241,7 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
# invalid kwargs are ignored and not passed to the pipeline
assert "some_invalid_kwarg" not in kwargs
@ -258,7 +260,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
HFLocalInvocationLayer(
"google/flan-t5-base",
task_name="text2text-generation",
tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
tokenizer=Mock(),
config=Mock(),
revision="1.1",
device="cpu",
@ -271,7 +273,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
# mock_get_task is not called as we provided task_name parameter
mock_get_task.assert_not_called()
args, kwargs = mock_pipeline.call_args
_, kwargs = mock_pipeline.call_args
# invalid kwargs are ignored and not passed to the pipeline
assert "first_invalid_kwarg" not in kwargs
@ -287,7 +289,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
assert len(kwargs) == 13
@pytest.mark.unit
@pytest.mark.integration
def test_text_generation_model():
# test simple prompting with text generation model
# by default, we force the model not return prompt text
@ -303,7 +305,7 @@ def test_text_generation_model():
assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
@pytest.mark.unit
@pytest.mark.integration
def test_text_generation_model_via_custom_pretrained_model():
tokenizer = AutoTokenizer.from_pretrained("bigscience/bigscience-small-testing")
model = BloomForCausalLM.from_pretrained("bigscience/bigscience-small-testing")
@ -320,31 +322,29 @@ def test_text_generation_model_via_custom_pretrained_model():
@pytest.mark.unit
def test_streaming_stream_param_in_constructor():
def test_streaming_stream_param_in_constructor(mock_pipeline, mock_get_task):
"""
Test stream parameter is correctly passed to pipeline invocation via HF streamer parameter
"""
layer = HFLocalInvocationLayer(stream=True)
layer.pipe = MagicMock()
layer.invoke(prompt="Tell me hello")
args, kwargs = layer.pipe.call_args
_, kwargs = layer.pipe.call_args
assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)
@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor():
def test_streaming_stream_handler_param_in_constructor(mock_pipeline, mock_get_task):
"""
Test stream parameter is correctly passed to pipeline invocation
"""
dtsh = DefaultTokenStreamingHandler()
layer = HFLocalInvocationLayer(stream_handler=dtsh)
layer.pipe = MagicMock()
layer.invoke(prompt="Tell me hello")
args, kwargs = layer.pipe.call_args
_, kwargs = layer.pipe.call_args
assert "streamer" in kwargs
hf_streamer = kwargs["streamer"]
@ -394,18 +394,17 @@ def test_supports(tmp_path):
@pytest.mark.unit
def test_stop_words_criteria_set():
def test_stop_words_criteria_set(mock_pipeline, mock_get_task):
"""
Test that stop words criteria is correctly set in pipeline invocation
"""
layer = HFLocalInvocationLayer(
model_name_or_path="hf-internal-testing/tiny-random-t5", task_name="text2text-generation"
)
layer.pipe = MagicMock()
layer.invoke(prompt="Tell me hello", stop_words=["hello", "world"])
args, kwargs = layer.pipe.call_args
_, kwargs = layer.pipe.call_args
assert "stopping_criteria" in kwargs
assert isinstance(kwargs["stopping_criteria"], StoppingCriteriaList)
assert len(kwargs["stopping_criteria"]) == 1
@ -468,7 +467,7 @@ def test_stop_words_not_being_found():
assert word in result[0]
@pytest.mark.unit
@pytest.mark.integration
def test_generation_kwargs_from_constructor():
"""
Test that generation_kwargs are correctly passed to pipeline invocation from constructor
@ -489,7 +488,7 @@ def test_generation_kwargs_from_constructor():
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
@pytest.mark.unit
@pytest.mark.integration
def test_generation_kwargs_from_invoke():
"""
Test that generation_kwargs passed to invoke are passed to the underlying HF model
@ -508,3 +507,35 @@ def test_generation_kwargs_from_invoke():
layer.invoke(prompt=the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
@pytest.mark.unit
def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
# prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
mock_prompt = "I am a tokenized prompt"
mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=mock_prompt)
mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
result = layer._ensure_token_limit(mock_prompt)
assert result == mock_prompt
@pytest.mark.unit
def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
# prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
correct_result = "I am a tokenized prompt of length"
mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=correct_result)
mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
result = layer._ensure_token_limit("I am a tokenized prompt of length eight")
assert result == correct_result

View File

@ -8,8 +8,23 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamin
from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer
@pytest.fixture
def mock_get_task():
# mock get_task function
with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
mock_get_task.return_value = "text2text-generation"
yield mock_get_task
@pytest.fixture
def mock_get_task_invalid():
with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
mock_get_task.return_value = "some-nonexistent-type"
yield mock_get_task
@pytest.mark.unit
def test_default_constructor():
def test_default_constructor(mock_auto_tokenizer):
"""
Test that the default constructor sets the correct values
"""
@ -22,7 +37,7 @@ def test_default_constructor():
@pytest.mark.unit
def test_constructor_with_model_kwargs():
def test_constructor_with_model_kwargs(mock_auto_tokenizer):
"""
Test that model_kwargs are correctly set in the constructor
and that model_kwargs_rejected are correctly filtered out
@ -41,7 +56,7 @@ def test_constructor_with_model_kwargs():
@pytest.mark.unit
def test_set_model_max_length():
def test_set_model_max_length(mock_auto_tokenizer):
"""
Test that model max length is set correctly
"""
@ -52,7 +67,7 @@ def test_set_model_max_length():
@pytest.mark.unit
def test_url():
def test_url(mock_auto_tokenizer):
"""
Test that the url is correctly set in the constructor
"""
@ -67,7 +82,7 @@ def test_url():
@pytest.mark.unit
def test_invoke_with_no_kwargs():
def test_invoke_with_no_kwargs(mock_auto_tokenizer):
"""
Test that invoke raises an error if no prompt is provided
"""
@ -78,7 +93,7 @@ def test_invoke_with_no_kwargs():
@pytest.mark.unit
def test_invoke_with_stop_words():
def test_invoke_with_stop_words(mock_auto_tokenizer):
"""
Test stop words are correctly passed to HTTP POST request
"""
@ -95,14 +110,14 @@ def test_invoke_with_stop_words():
assert mock_post.called
# Check if stop_words are passed to _post as stop parameter
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
assert "stop" in called_kwargs["data"]["parameters"]
assert called_kwargs["data"]["parameters"]["stop"] == stop_words
@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_constructor(stream):
def test_streaming_stream_param_in_constructor(mock_auto_tokenizer, stream):
"""
Test stream parameter is correctly passed to HTTP POST request via constructor
"""
@ -117,7 +132,7 @@ def test_streaming_stream_param_in_constructor(stream):
layer.invoke(prompt="Tell me hello")
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
@ -127,7 +142,7 @@ def test_streaming_stream_param_in_constructor(stream):
@pytest.mark.unit
@pytest.mark.parametrize("stream", [True, False])
def test_streaming_stream_param_in_method(stream):
def test_streaming_stream_param_in_method(mock_auto_tokenizer, stream):
"""
Test stream parameter is correctly passed to HTTP POST request via method
"""
@ -140,13 +155,7 @@ def test_streaming_stream_param_in_method(stream):
layer.invoke(prompt="Tell me hello", stream=stream)
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
# Check if stop_words are passed to _post as stop parameter
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
@ -156,7 +165,7 @@ def test_streaming_stream_param_in_method(stream):
@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor():
def test_streaming_stream_handler_param_in_constructor(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly passed to HTTP POST request via constructor
"""
@ -175,7 +184,7 @@ def test_streaming_stream_handler_param_in_constructor():
layer.invoke(prompt="Tell me hello")
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
@ -183,12 +192,12 @@ def test_streaming_stream_handler_param_in_constructor():
assert called_kwargs["stream"]
# stream_handler is passed as an instance of TokenStreamingHandler
called_args, called_kwargs = mock_post_stream.call_args
called_args, _ = mock_post_stream.call_args
assert isinstance(called_args[1], TokenStreamingHandler)
@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_constructor():
def test_streaming_no_stream_handler_param_in_constructor(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly passed to HTTP POST request via constructor
"""
@ -202,7 +211,7 @@ def test_streaming_no_stream_handler_param_in_constructor():
layer.invoke(prompt="Tell me hello")
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
# stream is always passed to _post
assert "stream" in called_kwargs
@ -212,7 +221,7 @@ def test_streaming_no_stream_handler_param_in_constructor():
@pytest.mark.unit
def test_streaming_stream_handler_param_in_method():
def test_streaming_stream_handler_param_in_method(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly passed to HTTP POST request via method
"""
@ -241,7 +250,7 @@ def test_streaming_stream_handler_param_in_method():
@pytest.mark.unit
def test_streaming_no_stream_handler_param_in_method():
def test_streaming_no_stream_handler_param_in_method(mock_auto_tokenizer):
"""
Test stream_handler parameter is correctly passed to HTTP POST request via method
"""
@ -257,7 +266,7 @@ def test_streaming_no_stream_handler_param_in_method():
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
# stream is always correctly passed to _post
assert "stream" in called_kwargs
@ -304,7 +313,7 @@ def test_ensure_token_limit_resize(caplog, model_name_or_path):
@pytest.mark.unit
def test_oasst_prompt_preprocessing():
def test_oasst_prompt_preprocessing(mock_auto_tokenizer):
model_name = "OpenAssistant/oasst-sft-1-pythia-12b"
layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
@ -318,7 +327,7 @@ def test_oasst_prompt_preprocessing():
assert result == ["Hello"]
assert mock_post.called
called_args, called_kwargs = mock_post.call_args
_, called_kwargs = mock_post.call_args
# OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"
@ -326,22 +335,20 @@ def test_oasst_prompt_preprocessing():
@pytest.mark.unit
def test_invalid_key():
with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
layer = HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
@pytest.mark.unit
def test_invalid_model():
with pytest.raises(ValueError, match="cannot be None or empty string"):
layer = HFInferenceEndpointInvocationLayer("fake_api", "")
HFInferenceEndpointInvocationLayer("fake_api", "")
@pytest.mark.unit
def test_supports():
def test_supports(mock_get_task):
"""
Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
"""
# doesn't support fake model
assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
# supports google/flan-t5-xxl with api_key
assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")
@ -353,3 +360,8 @@ def test_supports():
assert HFInferenceEndpointInvocationLayer.supports(
"https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
)
@pytest.mark.unit
def test_supports_not(mock_get_task_invalid):
assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")

View File

@ -9,6 +9,58 @@ from haystack.nodes.prompt.invocation_layer.handlers import (
)
@pytest.mark.unit
def test_prompt_handler_positive():
# prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
mock_prompt = "I am a tokenized prompt"
with patch(
"haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
) as mock_tokenizer:
tokenizer_instance = mock_tokenizer.return_value
tokenizer_instance.tokenize.return_value = mock_tokens
tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
prompt_handler = DefaultPromptHandler("model_path", 10, 3)
# Test with a prompt that does not exceed model_max_length when tokenized
result = prompt_handler(mock_prompt)
assert result == {
"resized_prompt": mock_prompt,
"prompt_length": 5,
"new_prompt_length": 5,
"model_max_length": 10,
"max_length": 3,
}
@pytest.mark.unit
def test_prompt_handler_negative():
# prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
mock_prompt = "I am a tokenized prompt of length"
with patch(
"haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
) as mock_tokenizer:
tokenizer_instance = mock_tokenizer.return_value
tokenizer_instance.tokenize.return_value = mock_tokens
tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
prompt_handler = DefaultPromptHandler("model_path", 10, 3)
result = prompt_handler(mock_prompt)
assert result == {
"resized_prompt": mock_prompt,
"prompt_length": 8,
"new_prompt_length": 7,
"model_max_length": 10,
"max_length": 3,
}
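
The two mock-based tests above encode a simple truncation rule. A rough standalone sketch of that rule (illustrative only, not the actual DefaultPromptHandler implementation):

def resize_prompt(tokens: list, model_max_length: int, max_length: int) -> list:
    # Keep the prompt as-is when the prompt tokens plus the generation budget
    # (max_length) fit into the model's context window; otherwise truncate so
    # that max_length tokens remain available for generation.
    if len(tokens) + max_length <= model_max_length:
        return tokens
    return tokens[: model_max_length - max_length]


assert len(resize_prompt(["tok"] * 5, model_max_length=10, max_length=3)) == 5
assert len(resize_prompt(["tok"] * 8, model_max_length=10, max_length=3)) == 7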
@pytest.mark.integration
def test_prompt_handler_basics():
handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)