haystack/test/prompt/test_prompt_model.py
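
"""Tests for PromptModel construction: supported-model resolution, API-key handling, and model_kwargs pass-through."""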

import pytest
import torch

from haystack.errors import OpenAIError
from haystack.nodes.prompt.prompt_model import PromptModel


@pytest.mark.integration
def test_create_prompt_model():
    model = PromptModel("google/flan-t5-small")
    assert model.model_name_or_path == "google/flan-t5-small"

    # the default model is google/flan-t5-base
    model = PromptModel()
    assert model.model_name_or_path == "google/flan-t5-base"

    with pytest.raises(OpenAIError):
        # davinci selected but no API key provided
        model = PromptModel("text-davinci-003")

    # the API key is not validated at construction time, so any string works here
    model = PromptModel("text-davinci-003", api_key="no need to provide a real key")
    assert model.model_name_or_path == "text-davinci-003"

    with pytest.raises(ValueError, match="Model some-random-model is not supported"):
        PromptModel("some-random-model")

    # we can also pass nested model kwargs to the PromptModel
    model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can also pass kwargs directly, see the HF Pipeline constructor
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16})
    assert model.model_name_or_path == "google/flan-t5-small"

    # we can't use device_map="auto" without the accelerate library installed
    with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"):
        model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"})
        assert model.model_name_or_path == "google/flan-t5-small"


@pytest.mark.integration
def test_create_prompt_model_dtype():
    # torch_dtype can be passed as the string "auto" ...
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "auto"})
    assert model.model_name_or_path == "google/flan-t5-small"

    # ... or as a fully qualified string such as "torch.bfloat16"
    model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": "torch.bfloat16"})
    assert model.model_name_or_path == "google/flan-t5-small"
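

# A minimal invocation sketch to complement the construction tests above. This test is
# not part of the original file; it assumes PromptModel exposes an invoke() method that
# runs the underlying model and returns a list of generated strings, as in Haystack v1.x.
@pytest.mark.integration
def test_invoke_prompt_model_sketch():
    model = PromptModel("google/flan-t5-small")
    # invoke() is assumed to return the generated texts as a list of strings
    output = model.invoke("What is the capital of Germany?")
    assert isinstance(output, list)
    assert all(isinstance(text, str) for text in output)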