haystack/test/prompt/test_prompt_model.py
Silvano Cerza c3abf73332
refactor: Rework prompt tests (#4600)
* Rework some PromptNode and PromptModel tests

* Remove duplicate code in PromptNode

* Fix mypy

* Fix test failing because of missing fixture

* Revert "Fix mypy"

This reverts commit e530295a06cb260d9a8bd89679534958cb3d9776.

* Revert "Remove duplicate code in PromptNode"

This reverts commit 4a678ae81504dcc78a737372c061d12dc8799639.
2023-04-06 14:47:44 +02:00

39 lines
1.4 KiB
Python

from unittest.mock import patch, Mock
import pytest
from haystack.nodes.prompt.prompt_model import PromptModel
from haystack.nodes.prompt.providers import PromptModelInvocationLayer
from .conftest import create_mock_layer_that_supports
@pytest.mark.unit
def test_constructor_with_default_model():
    """PromptModel() without arguments selects the layer that supports "google/flan-t5-base".

    Two mock invocation layers are installed as the only providers. Only the
    one supporting the default model name may be instantiated, and the model
    must end up using the instance that layer produced.
    """
    mock_layer = create_mock_layer_that_supports("google/flan-t5-base")
    another_layer = create_mock_layer_that_supports("another-model")
    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel()
        mock_layer.assert_called_once()
        another_layer.assert_not_called()
        # Compare, don't assign: assigning to the attribute would overwrite it
        # and the test could never fail.
        assert model.model_name_or_path == "google/flan-t5-base"
        # Calling a Mock returns its return_value, so the layer the model holds
        # must be exactly the object produced by the selected mock layer.
        assert model.model_invocation_layer is mock_layer.return_value
@pytest.mark.unit
def test_construtor_with_custom_model():
    """PromptModel("another-model") selects only the layer that supports "another-model".

    Two mock invocation layers are installed as the only providers. The first
    one does not support the requested name and must not be instantiated; the
    second one must be instantiated exactly once and used by the model.
    """
    mock_layer = create_mock_layer_that_supports("some-model")
    another_layer = create_mock_layer_that_supports("another-model")
    with patch.object(PromptModelInvocationLayer, "invocation_layer_providers", new=[mock_layer, another_layer]):
        model = PromptModel("another-model")
        mock_layer.assert_not_called()
        another_layer.assert_called_once()
        # Compare, don't assign: assigning to the attribute would overwrite it
        # and the test could never fail.
        assert model.model_name_or_path == "another-model"
        # Calling a Mock returns its return_value, so the layer the model holds
        # must be exactly the object produced by the selected mock layer.
        assert model.model_invocation_layer is another_layer.return_value
@pytest.mark.unit
def test_constructor_with_no_supported_model():
    """Constructing PromptModel with a name no provider supports raises ValueError.

    The error message must mention the rejected model name.
    """
    with pytest.raises(ValueError) as excinfo:
        PromptModel("some-random-model")
    # Same regex search against str(exception) that the `match=` keyword performs.
    excinfo.match("Model some-random-model is not supported")