mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-03 03:09:28 +00:00)

chore: block all HTTP requests in CI (#5088)

This commit is contained in:
parent 29a6bfe621
commit 65cdf36d72
@@ -2,8 +2,12 @@ import logging
 import os
 import re
 from typing import Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch

+
+from events import Events
+
+from haystack.agents.types import AgentTokenStreamingHandler
 from test.conftest import MockRetriever, MockPromptNode
 from unittest import mock
 import pytest
@@ -290,17 +294,6 @@ def test_tool_processes_answer_result_and_document_result():
     assert tool._process_result(Document(content="content")) == "content"


-def test_invalid_agent_template():
-    pn = PromptNode()
-    with pytest.raises(ValueError, match="some_non_existing_template not supported"):
-        Agent(prompt_node=pn, prompt_template="some_non_existing_template")
-
-    # if prompt_template is None, then we'll use zero-shot-react
-    a = Agent(prompt_node=pn, prompt_template=None)
-    assert isinstance(a.prompt_template, PromptTemplate)
-    assert a.prompt_template.name == "zero-shot-react"
-
-
 @pytest.mark.unit
 @patch.object(PromptNode, "prompt")
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
@@ -315,3 +308,25 @@ def test_default_template_order(mock_model, mock_prompt):

     a = Agent(prompt_node=pn, prompt_template="translation")
     assert a.prompt_template.name == "translation"
+
+
+@pytest.mark.unit
+def test_agent_with_unknown_prompt_template():
+    prompt_node = Mock()
+    prompt_node.get_prompt_template.return_value = None
+    with pytest.raises(ValueError, match="Prompt template 'invalid' not found"):
+        Agent(prompt_node=prompt_node, prompt_template="invalid")
+
+
+@pytest.mark.unit
+def test_agent_token_streaming_handler():
+    e = Events("on_new_token")
+
+    mock_callback = Mock()
+    e.on_new_token += mock_callback  # register the mock callback to the event
+
+    handler = AgentTokenStreamingHandler(events=e)
+    result = handler("test")
+
+    assert result == "test"
+    mock_callback.assert_called_once_with("test")  # assert that the mock callback was called with "test"
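The new test_agent_token_streaming_handler pins down the handler contract: every token is forwarded to all subscribers of the on_new_token event and then returned unchanged, so streaming never alters the generated text. A minimal sketch of that contract, assuming AgentTokenStreamingHandler is essentially a thin passthrough over the events library (the class below is a hypothetical stand-in, not the haystack implementation):

from events import Events

# Hypothetical stand-in matching the behavior the test asserts:
# notify subscribers, then pass the token through unchanged.
class PassthroughTokenHandler:
    def __init__(self, events: Events):
        self.events = events

    def __call__(self, token_received: str, **kwargs) -> str:
        # fire the on_new_token event for every registered subscriber
        self.events.on_new_token(token_received)
        return token_received

events = Events("on_new_token")
events.on_new_token += lambda t: print(t, end="")  # e.g. stream tokens to stdout

handler = PassthroughTokenHandler(events)
assert handler("test") == "test"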
@@ -1,17 +1,15 @@
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, Mock

 from haystack.agents.conversational import ConversationalAgent
 from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
-from haystack.nodes import PromptNode
+from test.conftest import MockPromptNode


 @pytest.mark.unit
 def test_init():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        agent = ConversationalAgent(prompt_node)
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node)

     # Test normal case
     assert isinstance(agent.memory, ConversationMemory)
@@ -25,32 +23,27 @@ def test_init():

 @pytest.mark.unit
 def test_init_with_summary_memory():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode(default_prompt_template="this is a test")
-        # Test with summary memory
-        agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
-        assert isinstance(agent.memory, ConversationSummaryMemory)
+    # Test with summary memory
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
+    assert isinstance(agent.memory, ConversationSummaryMemory)


 @pytest.mark.unit
 def test_init_with_no_memory():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        # Test with no memory
-        agent = ConversationalAgent(prompt_node, memory=NoMemory())
-        assert isinstance(agent.memory, NoMemory)
+    prompt_node = MockPromptNode()
+    # Test with no memory
+    agent = ConversationalAgent(prompt_node, memory=NoMemory())
+    assert isinstance(agent.memory, NoMemory)


 @pytest.mark.unit
 def test_run():
-    with patch("haystack.nodes.prompt.prompt_template.fetch_from_prompthub") as mock_prompthub:
-        mock_prompthub.side_effect = [("This is a test prompt. Use your knowledge to answer this question: {question}")]
-        prompt_node = PromptNode()
-        agent = ConversationalAgent(prompt_node)
+    prompt_node = MockPromptNode()
+    agent = ConversationalAgent(prompt_node)

     # Mock the Agent run method
-    agent.run = MagicMock(return_value="Hello")
-    assert agent.run("query") == "Hello"
-    agent.run.assert_called_once_with("query")
+    result = agent.run("query")
+
+    # empty answer
+    assert result["answers"][0].answer == ""
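All three tests now build agents from MockPromptNode (imported from test/conftest.py) instead of a real PromptNode, so constructing an agent no longer triggers a PromptHub fetch, which the new request blocker would turn into a hard failure. The conftest class itself is not part of this diff; a hypothetical minimal version consistent with the tests above might look like this:

from haystack.nodes import PromptNode

# Hypothetical sketch of test/conftest.py's MockPromptNode (the real one
# lives in the repo): skip PromptNode.__init__ so no model or template is
# resolved over the network, and return an empty completion, which is why
# test_run expects answer == "".
class MockPromptNode(PromptNode):
    def __init__(self):
        self.default_prompt_template = None

    def prompt(self, prompt_template=None, *args, **kwargs):
        return [""]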
@@ -1,9 +1,10 @@
 import unittest
+from typing import Optional, Union, List, Dict, Any
 from unittest import mock

 import pytest

-from haystack import Pipeline, Answer, Document
+from haystack import Pipeline, Answer, Document, BaseComponent, MultiLabel
 from haystack.agents.base import ToolsManager, Tool

@@ -160,3 +161,57 @@ def test_extract_tool_name_and_empty_tool_input(tools_manager):
     for example in examples:
         tool_name, tool_input = tools_manager.extract_tool_name_and_tool_input(example)
         assert tool_name == "Search" and tool_input == ""
+
+
+@pytest.mark.unit
+def test_node_as_tool():
+    # test that a component can be used as a tool
+    class ToolComponent(BaseComponent):
+        outgoing_edges = 1
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            pass
+
+        def run(self, **kwargs):
+            return "mocked_output"
+
+    tool = Tool(name="ToolA", pipeline_or_node=ToolComponent(), description="Tool A Description")
+    assert tool.run("input") == "mocked_output"
+
+
+@pytest.mark.unit
+def test_tools_manager_exception():
+    # tests exception raising in tools manager
+    class ToolComponent(BaseComponent):
+        outgoing_edges = 1
+
+        def run_batch(
+            self,
+            queries: Optional[Union[str, List[str]]] = None,
+            file_paths: Optional[List[str]] = None,
+            labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+            documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+            meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+            params: Optional[dict] = None,
+            debug: Optional[bool] = None,
+        ):
+            pass
+
+        def run(self, **kwargs):
+            raise Exception("mocked_exception")
+
+    fake_llm_response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy born"
+    tool = Tool(name="Search", pipeline_or_node=ToolComponent(), description="Search")
+    tools_manager = ToolsManager(tools=[tool])
+
+    with pytest.raises(Exception):
+        tools_manager.run_tool(llm_response=fake_llm_response)
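test_tools_manager_exception feeds the manager a fake LLM response in the ReAct convention its parser expects, "Tool: <name>" and "Tool Input: <text>" on their own lines, and relies on the tool's run() exception surfacing through run_tool. A hedged sketch of the parsing contract these tests exercise (the real implementation lives in haystack.agents.base.ToolsManager; the regex here is illustrative):

import re

# Illustrative parsing contract: pull the tool name and its input out of the
# LLM response; an absent input comes back as the empty string.
def extract_tool_name_and_tool_input(llm_response: str):
    match = re.search(r"Tool:\s*(\w+)\s*Tool Input:\s*(.*)", llm_response, re.DOTALL)
    if match:
        return match.group(1).strip(), match.group(2).strip()
    return None, None

name, tool_input = extract_tool_name_and_tool_input(
    "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy born"
)
assert name == "Search" and tool_input == "Where was Jeremy born"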
@@ -8,6 +8,7 @@ from pathlib import Path
 import os
 import re
 from functools import wraps
+from unittest.mock import patch

 import requests_cache
 import responses
@@ -847,5 +848,14 @@ def request_blocker(request: pytest.FixtureRequest, monkeypatch):
     marker = request.node.get_closest_marker("unit")
     if marker is None:
         return
-    monkeypatch.delattr("requests.sessions.Session")
-    monkeypatch.delattr("requests_cache.session.CachedSession")
+
+    def urlopen_mock(self, method, url, *args, **kwargs):
+        raise RuntimeError(f"The test was about to {method} {self.scheme}://{self.host}{url}")
+
+    monkeypatch.setattr("urllib3.connectionpool.HTTPConnectionPool.urlopen", urlopen_mock)
+
+
+@pytest.fixture
+def mock_auto_tokenizer():
+    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
+        yield mock_from_pretrained
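This conftest.py change is the heart of the commit. Deleting requests' Session classes only broke clients that go through requests; monkeypatching urllib3's HTTPConnectionPool.urlopen intercepts the choke point that requests, requests_cache and most other Python HTTP stacks ultimately call, so any test marked "unit" now fails loudly the moment it touches the network, with the method, host and path in the error. A self-contained sketch of the same technique (the fixture name block_http and the autouse=True flag are illustrative; haystack's fixture is called request_blocker and its registration is outside this hunk):

import pytest

@pytest.fixture(autouse=True)
def block_http(request: pytest.FixtureRequest, monkeypatch):
    # Only guard tests explicitly marked as unit tests.
    if request.node.get_closest_marker("unit") is None:
        return

    def urlopen_mock(self, method, url, *args, **kwargs):
        raise RuntimeError(f"The test was about to {method} {self.scheme}://{self.host}{url}")

    # urllib3 sits underneath requests and friends, so this blocks them all.
    monkeypatch.setattr("urllib3.connectionpool.HTTPConnectionPool.urlopen", urlopen_mock)

The mock_auto_tokenizer fixture added alongside it is the complementary half: tokenizer downloads via transformers do not necessarily go through urllib3 alone, so patching AutoTokenizer.from_pretrained keeps unit tests from ever asking the Hub for files.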
@@ -15,7 +15,7 @@ from ..conftest import fail_at_version

 @pytest.mark.unit
 @fail_at_version(1, 18)
-def test_seq2seq_deprecation():
+def test_seq2seq_deprecation(mock_auto_tokenizer):
     with pytest.warns(DeprecationWarning):
         try:
             Seq2SeqGenerator("non_existing_model/model")
@@ -8,10 +8,21 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import AnthropicClaudeInvocationLayer


+@pytest.fixture
+def mock_claude_tokenizer():
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer", autospec=True) as mock_tokenizer:
+        yield mock_tokenizer
+
+
+@pytest.fixture
+def mock_claude_request():
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
+        yield mock_request
+
+
 @pytest.mark.unit
-def test_default_costuctor():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_default_constructor(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     assert layer.api_key == "some_fake_key"
     assert layer.max_length == 200
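The two fixtures above replace the external touch points of AnthropicClaudeInvocationLayer (the tokenizer load and the HTTP call), so each test now just lists them as parameters instead of nesting with-patch blocks. The Mock(**{...}) construction used for fake responses throughout this file is standard unittest.mock: a dotted key such as "json.return_value" configures the child mock, for example:

from unittest.mock import Mock

# "json.return_value" configures what response.json() returns on the child mock.
response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})

assert response.ok
assert response.json() == {"completion": "some_result "}
assert response.json()["completion"].strip() == "some_result"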
@@ -20,7 +31,7 @@ def test_default_costuctor():


 @pytest.mark.unit
-def test_ignored_kwargs_are_filtered_in_init():
+def test_ignored_kwargs_are_filtered_in_init(mock_claude_tokenizer, mock_claude_request):
     kwargs = {
         "temperature": 1,
         "top_p": 5,
@@ -30,8 +41,8 @@ def test_ignored_kwargs_are_filtered_in_init():
         "stream_handler": DefaultTokenStreamingHandler(),
         "unkwnown_args": "this will be filtered out",
     }
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)
+
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", **kwargs)

     # Verify unexpected kwargs are filtered out
     assert len(layer.model_input_kwargs) == 6
@@ -45,9 +56,8 @@ def test_ignored_kwargs_are_filtered_in_init():


 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_no_kwargs(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     with pytest.raises(ValueError) as e:
         layer.invoke()
@@ -55,14 +65,12 @@ def test_invoke_with_no_kwargs():


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry")
-def test_invoke_with_prompt_only(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_prompt_only(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    mock_request.return_value = mock_response
+    mock_claude_request.return_value = mock_response

     res = layer.invoke(prompt="Some prompt")
     assert len(res) == 1
@@ -70,14 +78,13 @@ def test_invoke_with_prompt_only(mock_request):


 @pytest.mark.unit
-def test_invoke_with_kwargs():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_kwargs(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=["stop", "here"])
         assert len(res) == 1
         assert res[0] == "some_result"
@@ -92,19 +99,18 @@ def test_invoke_with_kwargs():
             "stream": False,
             "stop_sequences": ["stop", "here", "\n\nHuman: "],
         }
-        mock_request.assert_called_once()
-        assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
+        mock_invocation_request.assert_called_once()
+        assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)


 @pytest.mark.unit
-def test_invoke_with_none_stop_words():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_none_stop_words(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake response
     mock_response = Mock(**{"status_code": 200, "ok": True, "json.return_value": {"completion": "some_result "}})
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", max_length=300, stop_words=None)
         assert len(res) == 1
         assert res[0] == "some_result"
@@ -119,14 +125,13 @@ def test_invoke_with_none_stop_words():
             "stream": False,
             "stop_sequences": ["\n\nHuman: "],
         }
-        mock_request.assert_called_once()
-        assert mock_request.call_args.kwargs["data"] == json.dumps(expected_data)
+        mock_invocation_request.assert_called_once()
+        assert mock_invocation_request.call_args.kwargs["data"] == json.dumps(expected_data)


 @pytest.mark.unit
-def test_invoke_with_stream():
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
+def test_invoke_with_stream(mock_claude_tokenizer, mock_claude_request):
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     # Create a fake streamed response
     def mock_iter(self):
@@ -141,8 +146,8 @@ def test_invoke_with_stream():
     mock_response = Mock(**{"__iter__": mock_iter})

     # Verifies expected result is returned
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt", stream=True)

         assert len(res) == 1
@@ -150,14 +155,13 @@ def test_invoke_with_stream():


 @pytest.mark.unit
-def test_invoke_with_custom_stream_handler():
+def test_invoke_with_custom_stream_handler(mock_claude_tokenizer, mock_claude_request):
     # Create a mock stream handler that always return the same token when called
     mock_stream_handler = Mock()
     mock_stream_handler.return_value = "token"

     # Create a layer with a mocked stream handler
-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.Tokenizer"):
-        layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)
+    layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key", stream_handler=mock_stream_handler)

     # Create a fake streamed response
     def mock_iter(self):
@@ -171,8 +175,8 @@ def test_invoke_with_custom_stream_handler():

     mock_response = Mock(**{"__iter__": mock_iter})

-    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_request:
-        mock_request.return_value = mock_response
+    with patch("haystack.nodes.prompt.invocation_layer.anthropic_claude.request_with_retry") as mock_invocation_request:
+        mock_invocation_request.return_value = mock_response
         res = layer.invoke(prompt="Some prompt")

         assert len(res) == 1
@@ -186,7 +190,7 @@ def test_invoke_with_custom_stream_handler():


 @pytest.mark.unit
-def test_ensure_token_limit_fails_if_called_with_list():
+def test_ensure_token_limit_fails_if_called_with_list(mock_claude_tokenizer, mock_claude_request):
     layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")
     with pytest.raises(ValueError):
         layer._ensure_token_limit(prompt=[])
@@ -225,7 +229,7 @@ def test_ensure_token_limit_with_huge_max_length(caplog):


 @pytest.mark.unit
-def test_supports():
+def test_supports(mock_claude_tokenizer, mock_claude_request):
     layer = AnthropicClaudeInvocationLayer(api_key="some_fake_key")

     assert not layer.supports("claude")
@@ -1,14 +1,14 @@
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import Mock

 import pytest

-from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer


 @pytest.mark.unit
-def test_default_constructor():
+def test_default_constructor(mock_auto_tokenizer):
     """
     Test that the default constructor sets the correct values
     """
@@ -28,7 +28,7 @@ def test_default_constructor():


 @pytest.mark.unit
-def test_constructor_with_model_kwargs():
+def test_constructor_with_model_kwargs(mock_auto_tokenizer):
     """
     Test that model_kwargs are correctly set in the constructor
     and that model_kwargs_rejected are correctly filtered out
@@ -43,7 +43,7 @@ def test_constructor_with_model_kwargs():


 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
+def test_invoke_with_no_kwargs(mock_auto_tokenizer):
     """
     Test that invoke raises an error if no prompt is provided
     """
@@ -54,7 +54,7 @@ def test_invoke_with_no_kwargs():


 @pytest.mark.unit
-def test_invoke_with_stop_words():
+def test_invoke_with_stop_words(mock_auto_tokenizer):
     """
     Test stop words are correctly passed from PromptNode to wire in CohereInvocationLayer
     """
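Note that mock_auto_tokenizer is not defined anywhere in this file: pytest resolves the parameter name against the fixture added to test/conftest.py earlier in this commit, since fixtures declared in a conftest.py are visible to every test module below it. Requesting it by name is all a test needs to keep AutoTokenizer.from_pretrained off the network; a hypothetical illustration of the pattern:

import pytest
from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer

@pytest.mark.unit
def test_constructor_stays_offline(mock_auto_tokenizer):  # injected from conftest.py by name
    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
    # the patched AutoTokenizer.from_pretrained was called instead of downloading files
    assert mock_auto_tokenizer.called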
@@ -62,98 +62,135 @@ def test_invoke_with_stop_words():
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
     with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
         # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello"}]}')

         layer.invoke(prompt="Tell me hello", stop_words=stop_words)

-        assert mock_post.called
-
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, _ = mock_post.call_args
-        assert "end_sequences" in called_args[0]
-        assert called_args[0]["end_sequences"] == stop_words
+        assert mock_post.called
+        called_args, _ = mock_post.call_args
+        assert "end_sequences" in called_args[0]
+        assert called_args[0]["end_sequences"] == stop_words


 @pytest.mark.unit
-@pytest.mark.parametrize("using_constructor", [True, False])
-@pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param(using_constructor, stream):
+def test_streaming_stream_param_from_init(mock_auto_tokenizer):
     """
-    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from init
     """
-    if using_constructor:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=stream)
-    else:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)

     with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
-        # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
-
-        if using_constructor:
-            layer.invoke(prompt="Tell me hello")
-        else:
-            layer.invoke(prompt="Tell me hello", stream=stream)
-
-        assert mock_post.called
-
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, called_kwargs = mock_post.call_args
-
-        # stream is always passed to _post
-        assert "stream" in called_kwargs
-
-        # Check if stream is True, then stream is passed as True to _post
-        if stream:
-            assert called_kwargs["stream"]
-        # Check if stream is False, then stream is passed as False to _post
-        else:
-            assert not called_kwargs["stream"]
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        layer.invoke(prompt="Tell me hello")
+
+        assert mock_post.called
+        _, called_kwargs = mock_post.call_args
+
+        # stream is always passed to _post
+        assert "stream" in called_kwargs and called_kwargs["stream"]
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_init_no_stream(mock_auto_tokenizer):
+    """
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from init
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
+        layer.invoke(prompt="Tell me hello")
+
+        assert mock_post.called
+        _, called_kwargs = mock_post.call_args
+
+        # stream is always passed to _post
+        assert "stream" in called_kwargs
+        assert not bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_invoke(mock_auto_tokenizer):
+    """
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from invoke
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        layer.invoke(prompt="Tell me hello", stream=True)
+
+        assert mock_post.called
+        _, called_kwargs = mock_post.call_args
+
+        # stream is always passed to _post
+        assert "stream" in called_kwargs
+        assert bool(called_kwargs["stream"])
+
+
+@pytest.mark.unit
+def test_streaming_stream_param_from_invoke_no_stream(mock_auto_tokenizer):
+    """
+    Test stream parameter is correctly passed from PromptNode to wire in CohereInvocationLayer from invoke
+    """
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream=True)
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(text='{"generations":[{"text": "Hello there"}]}')
+        layer.invoke(prompt="Tell me hello", stream=False)
+
+        assert mock_post.called
+        _, called_kwargs = mock_post.call_args
+
+        # stream is always passed to _post
+        assert "stream" in called_kwargs
+        assert not bool(called_kwargs["stream"])


 @pytest.mark.unit
-@pytest.mark.parametrize("using_constructor", [True, False])
-@pytest.mark.parametrize("stream_handler", [DefaultTokenStreamingHandler(), None])
-def test_streaming_stream_handler_param(using_constructor, stream_handler):
+def test_streaming_stream_handler_param_from_init(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
     """
-    if using_constructor:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)
-    else:
-        layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+    stream_handler = DefaultTokenStreamingHandler()
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key", stream_handler=stream_handler)

-    with unittest.mock.patch(
-        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post"
-    ) as mock_post, unittest.mock.patch(
-        "haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._process_streaming_response"
-    ) as mock_post_stream:
-        # Mock the response, need to return a list of dicts
-        mock_post.return_value = MagicMock(text='{"generations":[{"text": "Hello"}]}')
-
-        if using_constructor:
-            layer.invoke(prompt="Tell me hello")
-        else:
-            layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
-
-        assert mock_post.called
-
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, called_kwargs = mock_post.call_args
-
-        # stream is always passed to _post
-        assert "stream" in called_kwargs
-
-        # if stream_handler is used then stream is always True
-        if stream_handler:
-            assert called_kwargs["stream"]
-            # and stream_handler is passed as an instance of TokenStreamingHandler
-            called_args, called_kwargs = mock_post_stream.call_args
-            assert "stream_handler" in called_kwargs
-            assert isinstance(called_kwargs["stream_handler"], TokenStreamingHandler)
-        # if stream_handler is not used then stream is always False
-        else:
-            assert not called_kwargs["stream"]
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        responses = layer.invoke(prompt="Tell me hello")
+
+        assert mock_post.called
+        _, called_kwargs = mock_post.call_args
+        assert "stream" in called_kwargs
+        assert bool(called_kwargs["stream"])
+        assert responses == ["Hello there"]
+
+
+@pytest.mark.unit
+def test_streaming_stream_handler_param_from_invoke(mock_auto_tokenizer):
+    """
+    Test stream_handler parameter is correctly from PromptNode passed to wire in CohereInvocationLayer
+    """
+    stream_handler = DefaultTokenStreamingHandler()
+    layer = CohereInvocationLayer(model_name_or_path="command", api_key="fake_key")
+
+    with unittest.mock.patch("haystack.nodes.prompt.invocation_layer.CohereInvocationLayer._post") as mock_post:
+        # Mock the response
+        mock_post.return_value = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
+        responses = layer.invoke(prompt="Tell me hello", stream_handler=stream_handler)
+
+        assert mock_post.called
+        _, called_kwargs = mock_post.call_args
+        assert "stream" in called_kwargs
+        assert bool(called_kwargs["stream"])
+        assert responses == ["Hello there"]


 @pytest.mark.unit
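The streamed-response mocks model Cohere's streaming wire format as one JSON object per line: iter_lines() yielding '{"text": "Hello"}' and '{"text": " there"}' is asserted to come back as the single completion "Hello there". A sketch of that accumulation, assuming the layer's streaming path simply JSON-decodes each line and feeds the text chunk through the token handler (the function name accumulate_stream is illustrative, not the haystack API):

import json
from unittest.mock import Mock

def accumulate_stream(response, handler=lambda token: token):
    # join the per-line "text" chunks into one completion string
    tokens = [json.loads(line)["text"] for line in response.iter_lines()]
    return "".join(handler(t) for t in tokens)

response = Mock(iter_lines=Mock(return_value=['{"text": "Hello"}', '{"text": " there"}']))
assert accumulate_stream(response) == "Hello there"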
@@ -181,13 +218,13 @@ def test_supports():


 @pytest.mark.unit
-def test_ensure_token_limit_fails_if_called_with_list():
+def test_ensure_token_limit_fails_if_called_with_list(mock_auto_tokenizer):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key")
     with pytest.raises(ValueError):
         layer._ensure_token_limit(prompt=[])


-@pytest.mark.unit
+@pytest.mark.integration
 def test_ensure_token_limit_with_small_max_length(caplog):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=10)
     res = layer._ensure_token_limit(prompt="Short prompt")
@@ -200,7 +237,7 @@ def test_ensure_token_limit_with_small_max_length(caplog):
     assert not caplog.records


-@pytest.mark.unit
+@pytest.mark.integration
 def test_ensure_token_limit_with_huge_max_length(caplog):
     layer = CohereInvocationLayer(model_name_or_path="command", api_key="some_fake_key", max_length=4090)
     res = layer._ensure_token_limit(prompt="Short prompt")
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch, Mock
 import pytest
 import torch
 from torch import device
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
+from transformers import AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig

 from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler, DefaultTokenStreamingHandler
@@ -13,8 +13,11 @@ from haystack.nodes.prompt.invocation_layer.hugging_face import StopWordsCriteria
 @pytest.fixture
 def mock_pipeline():
     # mock transformers pipeline
+    # model returning some mocked text for pipeline invocation
     with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mocked_pipeline:
-        mocked_pipeline.return_value = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
+        pipeline_mock = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
+        pipeline_mock.side_effect = lambda *args, **kwargs: [{"generated_text": "some mocked text"}]
+        mocked_pipeline.return_value = pipeline_mock
        yield mocked_pipeline
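The upgraded mock_pipeline fixture matters because several tests below now call layer.invoke() for real instead of swapping in layer.pipe = MagicMock(): the side_effect makes the mocked pipeline callable and return the same shape a real transformers text generation pipeline does, a list of dicts carrying a "generated_text" key. A small demonstration of the behavior the fixture wires up:

from unittest.mock import Mock

pipeline_mock = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
pipeline_mock.side_effect = lambda *args, **kwargs: [{"generated_text": "some mocked text"}]

# behaves like: pipe("Tell me hello", ...) -> [{"generated_text": "..."}]
out = pipeline_mock("Tell me hello", max_length=10)
assert out[0]["generated_text"] == "some mocked text"
assert pipeline_mock.call_args.kwargs["max_length"] == 10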
@@ -44,7 +47,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):

     mock_pipeline.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args

     # device is set to cpu by default and device_map is empty
     assert kwargs["device"] == device("cpu")
@@ -87,7 +90,7 @@ def test_constructor_with_model_name_and_device_map(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args

     # device is NOT set; device_map is auto because device_map takes precedence over device
     assert not kwargs["device"]
@@ -110,7 +113,7 @@ def test_constructor_with_torch_dtype(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     assert kwargs["torch_dtype"] == torch.float16
@@ -126,7 +129,7 @@ def test_constructor_with_torch_dtype_as_str(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     assert kwargs["torch_dtype"] == torch.float16
@@ -142,7 +145,7 @@ def test_constructor_with_torch_dtype_auto(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args
     assert kwargs["torch_dtype"] == "auto"
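The recurring args, kwargs to _, kwargs edits above are cosmetic cleanups riding along with the commit: Mock.call_args unpacks into a (positional, keyword) pair, and the underscore signals that only the keyword arguments are inspected. For example:

from unittest.mock import MagicMock

pipe = MagicMock()
pipe("prompt", streamer=None, stopping_criteria=None)

_, kwargs = pipe.call_args  # ignore the positional args
assert "streamer" in kwargs and "stopping_criteria" in kwargs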
@@ -206,9 +209,8 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     """
     Test that the constructor sets the pipeline with the pretrained model (if provided)
     """
-    # actual model and tokenizer passed to the pipeline
-    model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+    model = Mock()
+    tokenizer = Mock()

     HFLocalInvocationLayer(
         model_name_or_path="irrelevant_when_model_is_provided",
@@ -221,7 +223,7 @@ def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
     # mock_get_task is not called as we provided task_name parameter
     mock_get_task.assert_not_called()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args

     # correct tokenizer and model are set as well
     assert kwargs["tokenizer"] == tokenizer
@@ -239,7 +241,7 @@ def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
     mock_pipeline.assert_called_once()
     mock_get_task.assert_called_once()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args

     # invalid kwargs are ignored and not passed to the pipeline
     assert "some_invalid_kwarg" not in kwargs
@@ -258,7 +260,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     HFLocalInvocationLayer(
         "google/flan-t5-base",
         task_name="text2text-generation",
-        tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
+        tokenizer=Mock(),
         config=Mock(),
         revision="1.1",
         device="cpu",
@@ -271,7 +273,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     # mock_get_task is not called as we provided task_name parameter
     mock_get_task.assert_not_called()

-    args, kwargs = mock_pipeline.call_args
+    _, kwargs = mock_pipeline.call_args

     # invalid kwargs are ignored and not passed to the pipeline
     assert "first_invalid_kwarg" not in kwargs
@@ -287,7 +289,7 @@ def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
     assert len(kwargs) == 13


-@pytest.mark.unit
+@pytest.mark.integration
 def test_text_generation_model():
     # test simple prompting with text generation model
     # by default, we force the model not return prompt text
@@ -303,7 +305,7 @@ def test_text_generation_model():
     assert len(r[0]) > 0 and r[0].startswith("Hello big science!")


-@pytest.mark.unit
+@pytest.mark.integration
 def test_text_generation_model_via_custom_pretrained_model():
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bigscience-small-testing")
     model = BloomForCausalLM.from_pretrained("bigscience/bigscience-small-testing")
@@ -320,31 +322,29 @@ def test_text_generation_model_via_custom_pretrained_model():


 @pytest.mark.unit
-def test_streaming_stream_param_in_constructor():
+def test_streaming_stream_param_in_constructor(mock_pipeline, mock_get_task):
     """
     Test stream parameter is correctly passed to pipeline invocation via HF streamer parameter
     """
     layer = HFLocalInvocationLayer(stream=True)
-    layer.pipe = MagicMock()

     layer.invoke(prompt="Tell me hello")

-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)


 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_constructor():
+def test_streaming_stream_handler_param_in_constructor(mock_pipeline, mock_get_task):
     """
     Test stream parameter is correctly passed to pipeline invocation
     """
     dtsh = DefaultTokenStreamingHandler()
     layer = HFLocalInvocationLayer(stream_handler=dtsh)
-    layer.pipe = MagicMock()

     layer.invoke(prompt="Tell me hello")

-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "streamer" in kwargs
     hf_streamer = kwargs["streamer"]
@@ -394,18 +394,17 @@ def test_supports(tmp_path):


 @pytest.mark.unit
-def test_stop_words_criteria_set():
+def test_stop_words_criteria_set(mock_pipeline, mock_get_task):
     """
     Test that stop words criteria is correctly set in pipeline invocation
     """
     layer = HFLocalInvocationLayer(
         model_name_or_path="hf-internal-testing/tiny-random-t5", task_name="text2text-generation"
     )
-    layer.pipe = MagicMock()

     layer.invoke(prompt="Tell me hello", stop_words=["hello", "world"])

-    args, kwargs = layer.pipe.call_args
+    _, kwargs = layer.pipe.call_args
     assert "stopping_criteria" in kwargs
     assert isinstance(kwargs["stopping_criteria"], StoppingCriteriaList)
     assert len(kwargs["stopping_criteria"]) == 1
@@ -468,7 +467,7 @@ def test_stop_words_not_being_found():
         assert word in result[0]


-@pytest.mark.unit
+@pytest.mark.integration
 def test_generation_kwargs_from_constructor():
     """
     Test that generation_kwargs are correctly passed to pipeline invocation from constructor
@@ -489,7 +488,7 @@ def test_generation_kwargs_from_constructor():
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})


-@pytest.mark.unit
+@pytest.mark.integration
 def test_generation_kwargs_from_invoke():
     """
     Test that generation_kwargs passed to invoke are passed to the underlying HF model
@@ -508,3 +507,35 @@ def test_generation_kwargs_from_invoke():
         layer.invoke(prompt=the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))

     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
+    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
+    mock_prompt = "I am a tokenized prompt"
+
+    mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
+    mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=mock_prompt)
+    mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
+
+    layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
+    result = layer._ensure_token_limit(mock_prompt)
+
+    assert result == mock_prompt
+
+
+@pytest.mark.unit
+def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
+    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
+    correct_result = "I am a tokenized prompt of length"
+
+    mock_auto_tokenizer.tokenize = Mock(return_value=mock_tokens)
+    mock_auto_tokenizer.convert_tokens_to_string = Mock(return_value=correct_result)
+    mock_pipeline.return_value.tokenizer = mock_auto_tokenizer
+
+    layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=3, model_max_length=10)
+    result = layer._ensure_token_limit("I am a tokenized prompt of length eight")
+
+    assert result == correct_result
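Both new tests encode the same arithmetic: a prompt may occupy at most model_max_length - max_length tokens (10 - 3 = 7 here), so the 8-token prompt is truncated to its first 7 tokens and re-joined. A minimal sketch of that rule, assuming the layer delegates tokenization to the pipeline's tokenizer exactly as the mocks above wire up (the standalone function is illustrative, not the haystack method):

def ensure_token_limit(prompt: str, tokenizer, model_max_length: int, max_length: int) -> str:
    # reserve max_length tokens for the generated answer; the prompt gets the rest
    tokens = tokenizer.tokenize(prompt)
    budget = model_max_length - max_length
    if len(tokens) <= budget:
        return prompt
    return tokenizer.convert_tokens_to_string(tokens[:budget])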
@@ -8,8 +8,23 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer


+@pytest.fixture
+def mock_get_task():
+    # mock get_task function
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
+        mock_get_task.return_value = "text2text-generation"
+        yield mock_get_task
+
+
+@pytest.fixture
+def mock_get_task_invalid():
+    with patch("haystack.nodes.prompt.invocation_layer.hugging_face_inference.get_task") as mock_get_task:
+        mock_get_task.return_value = "some-nonexistent-type"
+        yield mock_get_task
+
+
 @pytest.mark.unit
-def test_default_constructor():
+def test_default_constructor(mock_auto_tokenizer):
     """
     Test that the default constructor sets the correct values
     """
@@ -22,7 +37,7 @@ def test_default_constructor():


 @pytest.mark.unit
-def test_constructor_with_model_kwargs():
+def test_constructor_with_model_kwargs(mock_auto_tokenizer):
     """
     Test that model_kwargs are correctly set in the constructor
     and that model_kwargs_rejected are correctly filtered out
@@ -41,7 +56,7 @@ def test_constructor_with_model_kwargs():


 @pytest.mark.unit
-def test_set_model_max_length():
+def test_set_model_max_length(mock_auto_tokenizer):
     """
     Test that model max length is set correctly
     """
@@ -52,7 +67,7 @@ def test_set_model_max_length():


 @pytest.mark.unit
-def test_url():
+def test_url(mock_auto_tokenizer):
     """
     Test that the url is correctly set in the constructor
     """
@@ -67,7 +82,7 @@ def test_url():


 @pytest.mark.unit
-def test_invoke_with_no_kwargs():
+def test_invoke_with_no_kwargs(mock_auto_tokenizer):
     """
     Test that invoke raises an error if no prompt is provided
     """
@@ -78,7 +93,7 @@ def test_invoke_with_no_kwargs():


 @pytest.mark.unit
-def test_invoke_with_stop_words():
+def test_invoke_with_stop_words(mock_auto_tokenizer):
     """
     Test stop words are correctly passed to HTTP POST request
     """
@@ -95,14 +110,14 @@ def test_invoke_with_stop_words():
         assert mock_post.called

         # Check if stop_words are passed to _post as stop parameter
-        called_args, called_kwargs = mock_post.call_args
+        _, called_kwargs = mock_post.call_args
         assert "stop" in called_kwargs["data"]["parameters"]
         assert called_kwargs["data"]["parameters"]["stop"] == stop_words


 @pytest.mark.unit
 @pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param_in_constructor(stream):
+def test_streaming_stream_param_in_constructor(mock_auto_tokenizer, stream):
     """
     Test stream parameter is correctly passed to HTTP POST request via constructor
     """
@@ -117,7 +132,7 @@ def test_streaming_stream_param_in_constructor(stream):
         layer.invoke(prompt="Tell me hello")

         assert mock_post.called
-        called_args, called_kwargs = mock_post.call_args
+        _, called_kwargs = mock_post.call_args

         # stream is always passed to _post
         assert "stream" in called_kwargs
@@ -127,7 +142,7 @@ def test_streaming_stream_param_in_constructor(stream):

 @pytest.mark.unit
 @pytest.mark.parametrize("stream", [True, False])
-def test_streaming_stream_param_in_method(stream):
+def test_streaming_stream_param_in_method(mock_auto_tokenizer, stream):
     """
     Test stream parameter is correctly passed to HTTP POST request via method
     """
@@ -140,13 +155,7 @@ def test_streaming_stream_param_in_method(stream):
         layer.invoke(prompt="Tell me hello", stream=stream)

         assert mock_post.called
-        called_args, called_kwargs = mock_post.call_args
-
-        # stream is always passed to _post
-        assert "stream" in called_kwargs
-
-        # Check if stop_words are passed to _post as stop parameter
-        called_args, called_kwargs = mock_post.call_args
+        _, called_kwargs = mock_post.call_args

         # stream is always passed to _post
         assert "stream" in called_kwargs
@@ -156,7 +165,7 @@ def test_streaming_stream_param_in_method(stream):


 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_constructor():
+def test_streaming_stream_handler_param_in_constructor(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via constructor
     """
@@ -175,7 +184,7 @@ def test_streaming_stream_handler_param_in_constructor():
         layer.invoke(prompt="Tell me hello")

         assert mock_post.called
-        called_args, called_kwargs = mock_post.call_args
+        _, called_kwargs = mock_post.call_args

         # stream is always passed to _post
         assert "stream" in called_kwargs
@@ -183,12 +192,12 @@ def test_streaming_stream_handler_param_in_constructor():
         assert called_kwargs["stream"]

         # stream_handler is passed as an instance of TokenStreamingHandler
-        called_args, called_kwargs = mock_post_stream.call_args
+        called_args, _ = mock_post_stream.call_args
         assert isinstance(called_args[1], TokenStreamingHandler)


 @pytest.mark.unit
-def test_streaming_no_stream_handler_param_in_constructor():
+def test_streaming_no_stream_handler_param_in_constructor(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via constructor
     """
@@ -202,7 +211,7 @@ def test_streaming_no_stream_handler_param_in_constructor():
         layer.invoke(prompt="Tell me hello")

         assert mock_post.called
-        called_args, called_kwargs = mock_post.call_args
+        _, called_kwargs = mock_post.call_args

         # stream is always passed to _post
         assert "stream" in called_kwargs
@@ -212,7 +221,7 @@ def test_streaming_no_stream_handler_param_in_constructor():


 @pytest.mark.unit
-def test_streaming_stream_handler_param_in_method():
+def test_streaming_stream_handler_param_in_method(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via method
     """
@@ -241,7 +250,7 @@ def test_streaming_stream_handler_param_in_method():


 @pytest.mark.unit
-def test_streaming_no_stream_handler_param_in_method():
+def test_streaming_no_stream_handler_param_in_method(mock_auto_tokenizer):
     """
     Test stream_handler parameter is correctly passed to HTTP POST request via method
     """
@@ -257,7 +266,7 @@ def test_streaming_no_stream_handler_param_in_method():

         assert mock_post.called

-        called_args, called_kwargs = mock_post.call_args
+        _, called_kwargs = mock_post.call_args

         # stream is always correctly passed to _post
         assert "stream" in called_kwargs
@@ -304,7 +313,7 @@ def test_ensure_token_limit_resize(caplog, model_name_or_path):


 @pytest.mark.unit
-def test_oasst_prompt_preprocessing():
+def test_oasst_prompt_preprocessing(mock_auto_tokenizer):
     model_name = "OpenAssistant/oasst-sft-1-pythia-12b"

     layer = HFInferenceEndpointInvocationLayer("fake_api_key", model_name)
@@ -318,7 +327,7 @@ def test_oasst_prompt_preprocessing():
     assert result == ["Hello"]
     assert mock_post.called

-    called_args, called_kwargs = mock_post.call_args
+    _, called_kwargs = mock_post.call_args
     # OpenAssistant/oasst-sft-1-pythia-12b prompts are preprocessed and wrapped in tokens below
     assert called_kwargs["data"]["inputs"] == "<|prompter|>Tell me hello<|endoftext|><|assistant|>"

@@ -326,22 +335,20 @@ def test_oasst_prompt_preprocessing():
 @pytest.mark.unit
 def test_invalid_key():
     with pytest.raises(ValueError, match="must be a valid Hugging Face token"):
-        layer = HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")
+        HFInferenceEndpointInvocationLayer("", "irrelevant_model_name")


 @pytest.mark.unit
 def test_invalid_model():
     with pytest.raises(ValueError, match="cannot be None or empty string"):
-        layer = HFInferenceEndpointInvocationLayer("fake_api", "")
+        HFInferenceEndpointInvocationLayer("fake_api", "")


 @pytest.mark.unit
-def test_supports():
+def test_supports(mock_get_task):
     """
     Test that supports returns True correctly for HFInferenceEndpointInvocationLayer
     """
-    # doesn't support fake model
-    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
-
     # supports google/flan-t5-xxl with api_key
     assert HFInferenceEndpointInvocationLayer.supports("google/flan-t5-xxl", api_key="fake_key")
@@ -353,3 +360,8 @@ def test_supports():
     assert HFInferenceEndpointInvocationLayer.supports(
         "https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud", api_key="fake_key"
     )
+
+
+@pytest.mark.unit
+def test_supports_not(mock_get_task_invalid):
+    assert not HFInferenceEndpointInvocationLayer.supports("fake_model", api_key="fake_key")
@@ -9,6 +9,58 @@ from haystack.nodes.prompt.invocation_layer.handlers import (
 )


+@pytest.mark.unit
+def test_prompt_handler_positive():
+    # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt"]
+    mock_prompt = "I am a tokenized prompt"
+
+    with patch(
+        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
+    ) as mock_tokenizer:
+        tokenizer_instance = mock_tokenizer.return_value
+        tokenizer_instance.tokenize.return_value = mock_tokens
+        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
+
+        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
+
+        # Test with a prompt that does not exceed model_max_length when tokenized
+        result = prompt_handler(mock_prompt)
+
+        assert result == {
+            "resized_prompt": mock_prompt,
+            "prompt_length": 5,
+            "new_prompt_length": 5,
+            "model_max_length": 10,
+            "max_length": 3,
+        }
+
+
+@pytest.mark.unit
+def test_prompt_handler_negative():
+    # prompt of length 8 + max_length of 3 = 11, which is more than model_max_length of 10, so we resize to 7
+    mock_tokens = ["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"]
+    mock_prompt = "I am a tokenized prompt of length"
+
+    with patch(
+        "haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained", autospec=True
+    ) as mock_tokenizer:
+        tokenizer_instance = mock_tokenizer.return_value
+        tokenizer_instance.tokenize.return_value = mock_tokens
+        tokenizer_instance.convert_tokens_to_string.return_value = mock_prompt
+
+        prompt_handler = DefaultPromptHandler("model_path", 10, 3)
+        result = prompt_handler(mock_prompt)
+
+        assert result == {
+            "resized_prompt": mock_prompt,
+            "prompt_length": 8,
+            "new_prompt_length": 7,
+            "model_max_length": 10,
+            "max_length": 3,
+        }
+
+
 @pytest.mark.integration
 def test_prompt_handler_basics():
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
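Read together, the two unit tests pin down the DefaultPromptHandler contract: tokenize the prompt, compare prompt_length + max_length against model_max_length, and report both the original and resized lengths. A reference sketch of the arithmetic the asserts describe (the real handler additionally loads its tokenizer via AutoTokenizer.from_pretrained, which is exactly the call the patch above intercepts; resized_prompt is omitted here because re-joining tokens is tokenizer-specific):

def resize_prompt(tokens, model_max_length: int, max_length: int) -> dict:
    budget = model_max_length - max_length          # 10 - 3 = 7
    resized = tokens[: min(len(tokens), budget)]
    return {
        "prompt_length": len(tokens),               # 5 stays 5; 8 shrinks...
        "new_prompt_length": len(resized),          # ...to 7
        "model_max_length": model_max_length,
        "max_length": max_length,
    }

assert resize_prompt(["I", "am", "a", "tokenized", "prompt"], 10, 3)["new_prompt_length"] == 5
assert resize_prompt(["I", "am", "a", "tokenized", "prompt", "of", "length", "eight"], 10, 3)["new_prompt_length"] == 7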