# Mirrored from https://github.com/deepset-ai/haystack.git (synced 2025-08-11 18:17:53 +00:00).
import pytest
import torch

from haystack.errors import OpenAIError
from haystack.nodes.prompt.prompt_model import PromptModel
@pytest.mark.integration
def test_create_prompt_model():
    """Smoke-test PromptModel construction across local HF models and the OpenAI backend.

    Covers: explicit and default model selection, the missing-API-key error for
    OpenAI models, rejection of unsupported model names, and the two accepted
    shapes for forwarding kwargs to the underlying HF pipeline.
    """
    model = PromptModel("google/flan-t5-small")
    assert model.model_name_or_path == "google/flan-t5-small"

    # Default model when none is specified.
    model = PromptModel()
    assert model.model_name_or_path == "google/flan-t5-base"

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        PromptModel("text-davinci-003")

    model = PromptModel("text-davinci-003", api_key="no need to provide a real key")
    assert model.model_name_or_path == "text-davinci-003"

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptModel("some-random-model")

    # we can also pass model kwargs to the PromptModel
    model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can also pass kwargs directly, see HF Pipeline constructor
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can't use device_map auto without accelerate library installed
    with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"):
        PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"})
    # NOTE(review): the original asserted model.model_name_or_path here, but that name is
    # never rebound when the constructor raises, so the check was dead/misleading — removed.
def test_create_prompt_model_dtype():
    """PromptModel accepts torch_dtype passed as a string: "auto" or "torch.<dtype>"."""
    for dtype_spec in ("auto", "torch.bfloat16"):
        model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": dtype_spec})
        assert model.model_name_or_path == "google/flan-t5-small"