mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-11 10:07:50 +00:00

* Simplify HFLocalInvocationLayer, move/add unit tests * PR feedback * Better pipeline invocation, add mocked tests * Minor improvements * Mock pipeline directly, unit test updates * PR feedback, change pytest type to integration * Mock supports unit test * add full stop * PR feedback, improve unit tests * Add mock_get_task fixture * Further improve unit tests * Minor unit test improvement * Add unit tests, increase coverage * Add unit tests, increase test coverage * Small optimization, improve _ensure_token_limit unit test --------- Co-authored-by: Darja Fokina <daria.f93@gmail.com>
39 lines
1.5 KiB
Python
39 lines
1.5 KiB
Python
from unittest.mock import patch, Mock
|
|
|
|
import pytest
|
|
|
|
from haystack.nodes.prompt.prompt_model import PromptModel
|
|
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer, HFLocalInvocationLayer
|
|
|
|
from .conftest import create_mock_layer_that_supports
|
|
|
|
|
|
@pytest.mark.unit
def test_constructor_with_default_model():
    """PromptModel() with no arguments must use the provider that supports "google/flan-t5-base".

    Two mock invocation layer providers are registered in place of the real ones; only
    the provider supporting the default model name may be instantiated.
    """
    mock_layer = create_mock_layer_that_supports("google/flan-t5-base")
    another_layer = create_mock_layer_that_supports("another-model")

    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel()
        mock_layer.assert_called_once()
        another_layer.assert_not_called()
        # Verify the model name the provider was constructed with. Reading it back from
        # `model.model_invocation_layer` is presumably useless here: that object would be
        # the mock's `return_value`, whose attributes are auto-created Mocks.
        # Assumes PromptModel passes the name as the `model_name_or_path` keyword — TODO confirm.
        assert mock_layer.call_args.kwargs["model_name_or_path"] == "google/flan-t5-base"
        assert model is not None
|
|
|
|
|
|
@pytest.mark.unit
def test_construtor_with_custom_model():
    """PromptModel("another-model") must use the provider that supports "another-model".

    Two mock invocation layer providers are registered in place of the real ones; the
    provider for a different model must not be instantiated.
    """
    mock_layer = create_mock_layer_that_supports("some-model")
    another_layer = create_mock_layer_that_supports("another-model")

    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel("another-model")
        mock_layer.assert_not_called()
        another_layer.assert_called_once()
        # Verify the model name the provider was constructed with. Reading it back from
        # `model.model_invocation_layer` is presumably useless here: that object would be
        # the mock's `return_value`, whose attributes are auto-created Mocks.
        # Assumes PromptModel passes the name as the `model_name_or_path` keyword — TODO confirm.
        assert another_layer.call_args.kwargs["model_name_or_path"] == "another-model"
        assert model is not None
|
|
|
|
|
|
@pytest.mark.unit
def test_constructor_with_no_supported_model():
    """Constructing PromptModel with a model name no provider supports raises ValueError."""
    unsupported_name = "some-random-model"
    expected_message = f"Model {unsupported_name} is not supported"

    with pytest.raises(ValueError, match=expected_message):
        PromptModel(unsupported_name)
|