feat: pass model parameters to HFLocalInvocationLayer via model_kwargs, enabling direct model usage (#4956)

* Simplify HFLocalInvocationLayer, move/add unit tests

* PR feedback

* Better pipeline invocation, add mocked tests

* Minor improvements

* Mock pipeline directly,  unit test updates

* PR feedback, change pytest type to integration

* Mock supports unit test

* add full stop

* PR feedback, improve unit tests

* Add mock_get_task fixture

* Further improve unit tests

* Minor unit test improvement

* Add unit tests, increase coverage

* Add unit tests, increase test coverage

* Small optimization, improve _ensure_token_limit unit test

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2023-06-07 13:34:45 +02:00 committed by GitHub
parent eca8f66ffa
commit e3b069620b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 598 additions and 318 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional, Union, List, Dict
from typing import Optional, Union, List, Dict, Any
import logging
import os
@ -11,6 +11,7 @@ from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
GenerationConfig,
Pipeline,
)
from transformers.pipelines import get_task
@ -50,7 +51,8 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
includes: task_name, trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
includes: "task", "model", "config", "tokenizer", "feature_extractor", "revision", "use_auth_token",
"device_map", "device", "torch_dtype", "trust_remote_code", "model_kwargs", and "pipeline_class".
For more details about pipeline kwargs in general, see
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
@ -72,81 +74,41 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
self.__class__.__name__,
self.devices[0],
)
# Due to reflective construction of all invocation layers we might receive some
# unknown kwargs, so we need to take only the relevant.
# For more details refer to Hugging Face pipeline documentation
# Do not use `device_map` AND `device` at the same time as they will conflict
model_input_kwargs = {
key: kwargs[key]
for key in [
"model_kwargs",
"trust_remote_code",
"revision",
"feature_extractor",
"tokenizer",
"config",
"use_fast",
"torch_dtype",
"device_map",
"generation_kwargs",
"model_max_length",
"stream",
"stream_handler",
]
if key in kwargs
}
# flatten model_kwargs one level
if "model_kwargs" in model_input_kwargs:
mkwargs = model_input_kwargs.pop("model_kwargs")
model_input_kwargs.update(mkwargs)
if "device" not in kwargs:
kwargs["device"] = self.devices[0]
# save stream settings and stream_handler for pipeline invocation
self.stream_handler = model_input_kwargs.pop("stream_handler", None)
self.stream = model_input_kwargs.pop("stream", False)
self.stream_handler = kwargs.get("stream_handler", None)
self.stream = kwargs.get("stream", False)
# save generation_kwargs for pipeline invocation
self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
model_max_length = model_input_kwargs.pop("model_max_length", None)
torch_dtype = model_input_kwargs.get("torch_dtype")
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if "torch." in torch_dtype:
torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
elif torch_dtype == "auto":
torch_dtype_resolved = torch_dtype
else:
raise ValueError(
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
)
elif isinstance(torch_dtype, torch.dtype):
torch_dtype_resolved = torch_dtype
else:
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
model_input_kwargs["torch_dtype"] = torch_dtype_resolved
if len(model_input_kwargs) > 0:
logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
self.generation_kwargs = kwargs.get("generation_kwargs", {})
# If task_name is not provided, get the task name from the model name or path (uses HFApi)
if "task_name" in kwargs:
self.task_name = kwargs.get("task_name")
else:
self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
self.pipe = pipeline(
task=self.task_name, # task_name is used to determine the pipeline type
model=model_name_or_path,
device=self.devices[0] if "device_map" not in model_input_kwargs else None,
use_auth_token=self.use_auth_token,
model_kwargs=model_input_kwargs,
self.task_name = (
kwargs.get("task_name")
if "task_name" in kwargs
else get_task(model_name_or_path, use_auth_token=use_auth_token)
)
# we check in supports class method if task_name is supported but here we check again as
# we could have gotten the task_name from kwargs
if self.task_name not in ["text2text-generation", "text-generation"]:
raise ValueError(
f"Task name {self.task_name} is not supported. "
f"We only support text2text-generation and text-generation tasks."
)
pipeline_kwargs = self._prepare_pipeline_kwargs(
task=self.task_name, model_name_or_path=model_name_or_path, use_auth_token=use_auth_token, **kwargs
)
# create the transformer pipeline
self.pipe: Pipeline = pipeline(**pipeline_kwargs)
# This is how the default max_length is determined for Text2TextGenerationPipeline shown here
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/pipelines/text2text_generation.html
# max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or self.pipe.model.config.max_length
model_max_length = kwargs.get("model_max_length", None)
# we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
# embeddings and can accept sequences of more than 512 tokens
if model_max_length is not None:
@ -160,6 +122,37 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
self.pipe.tokenizer.model_max_length,
)
def _prepare_pipeline_kwargs(self, **kwargs) -> Dict[str, Any]:
"""
Sanitizes and prepares the kwargs passed to the transformers pipeline function.
For more details about pipeline kwargs in general, see Hugging Face
[documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
"""
# as device and device_map are mutually exclusive, we set device to None if device_map is provided
device_map = kwargs.get("device_map", None)
device = kwargs.get("device") if device_map is None else None
# prepare torch_dtype for pipeline invocation
torch_dtype = self._extract_torch_dtype(**kwargs)
# and the model (prefer model instance over model_name_or_path str identifier)
model = kwargs.get("model") or kwargs.get("model_name_or_path")
pipeline_kwargs = {
"task": kwargs.get("task", None),
"model": model,
"config": kwargs.get("config", None),
"tokenizer": kwargs.get("tokenizer", None),
"feature_extractor": kwargs.get("feature_extractor", None),
"revision": kwargs.get("revision", None),
"use_auth_token": kwargs.get("use_auth_token", None),
"device_map": device_map,
"device": device,
"torch_dtype": torch_dtype,
"trust_remote_code": kwargs.get("trust_remote_code", False),
"model_kwargs": kwargs.get("model_kwargs", {}),
"pipeline_class": kwargs.get("pipeline_class", None),
}
return pipeline_kwargs
def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
@ -172,14 +165,15 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
stop_words = kwargs.pop("stop_words", None)
top_k = kwargs.pop("top_k", None)
# either stream is True (will use default handler) or stream_handler is provided for custom handler
stream = kwargs.get("stream", self.stream) or kwargs.get("stream_handler", self.stream_handler) is not None
stream = kwargs.get("stream", self.stream)
stream_handler = kwargs.get("stream_handler", self.stream_handler)
stream = stream or stream_handler is not None
if kwargs and "prompt" in kwargs:
prompt = kwargs.pop("prompt")
# Consider only Text2TextGenerationPipeline and TextGenerationPipeline relevant, ignore others
# For more details refer to Hugging Face Text2TextGenerationPipeline and TextGenerationPipeline
# documentation
# TODO resolve these kwargs from the pipeline signature
model_input_kwargs = {
key: kwargs[key]
for key in [
@ -227,7 +221,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
model_input_kwargs["max_length"] = self.max_length
if stream:
stream_handler: TokenStreamingHandler = kwargs.pop("stream_handler", DefaultTokenStreamingHandler())
stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()
model_input_kwargs["streamer"] = HFTokenStreamingHandler(self.pipe.tokenizer, stream_handler)
output = self.pipe(prompt, **model_input_kwargs)
@ -248,7 +242,8 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
:param prompt: Prompt text to be sent to the generative model.
"""
model_max_length = self.pipe.tokenizer.model_max_length
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
tokenized_prompt = self.pipe.tokenizer.tokenize(prompt)
n_prompt_tokens = len(tokenized_prompt)
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
return prompt
@ -263,20 +258,38 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
model_max_length,
)
tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
tokenized_payload[: model_max_length - n_answer_tokens]
tokenized_prompt[: model_max_length - n_answer_tokens]
)
return decoded_string
def _extract_torch_dtype(self, **kwargs) -> Optional[torch.dtype]:
torch_dtype_resolved = None
torch_dtype = kwargs.get("torch_dtype", None)
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if "torch." in torch_dtype:
torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
elif torch_dtype == "auto":
torch_dtype_resolved = torch_dtype
else:
raise ValueError(
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
)
elif isinstance(torch_dtype, torch.dtype):
torch_dtype_resolved = torch_dtype
else:
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
return torch_dtype_resolved
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
task_name: Optional[str] = None
task_name: Optional[str] = kwargs.get("task_name", None)
if os.path.exists(model_name_or_path):
return True
try:
task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
task_name = task_name or get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
except RuntimeError:
# This will fail for all non-HF models
return False
@ -300,4 +313,5 @@ class StopWordsCriteria(StoppingCriteria):
self.stop_words = tokenizer(stop_words, add_special_tokens=False, return_tensors="pt").to(device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(torch.isin(input_ids[-1], self.stop_words["input_ids"]))
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
return any(all(stop_word) for stop_word in stop_result)

View File

@ -0,0 +1,510 @@
from unittest.mock import MagicMock, patch, Mock
import pytest
import torch
from torch import device
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler, DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.hugging_face import StopWordsCriteria
@pytest.fixture
def mock_pipeline():
# mock transformers pipeline
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mocked_pipeline:
mocked_pipeline.return_value = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
yield mocked_pipeline
@pytest.fixture
def mock_get_task():
# mock get_task function
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task") as mock_get_task:
mock_get_task.return_value = "text2text-generation"
yield mock_get_task
@pytest.mark.unit
def test_constructor_with_invalid_task_name(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with invalid task_name
"""
with pytest.raises(ValueError, match="Task name custom-text2text-generation is not supported"):
HFLocalInvocationLayer("google/flan-t5-base", task_name="custom-text2text-generation")
@pytest.mark.unit
def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with model_name_or_path only
"""
HFLocalInvocationLayer("google/flan-t5-base")
mock_pipeline.assert_called_once()
args, kwargs = mock_pipeline.call_args
# device is set to cpu by default and device_map is empty
assert kwargs["device"] == device("cpu")
assert not kwargs["device_map"]
# correct task and model are set
assert kwargs["task"] == "text2text-generation"
assert kwargs["model"] == "google/flan-t5-base"
# no matter what kwargs we pass or don't pass, there are always 13 predefined kwargs passed to the pipeline
assert len(kwargs) == 13
# and these kwargs are passed to the pipeline
assert list(kwargs.keys()) == [
"task",
"model",
"config",
"tokenizer",
"feature_extractor",
"revision",
"use_auth_token",
"device_map",
"device",
"torch_dtype",
"trust_remote_code",
"model_kwargs",
"pipeline_class",
]
@pytest.mark.unit
def test_constructor_with_model_name_and_device_map(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with model_name_or_path and device_map
"""
layer = HFLocalInvocationLayer("google/flan-t5-base", device="cpu", device_map="auto")
assert layer.pipe == mock_pipeline.return_value
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
# device is NOT set; device_map is auto because device_map takes precedence over device
assert not kwargs["device"]
assert kwargs["device_map"] and kwargs["device_map"] == "auto"
# correct task and model are set as well
assert kwargs["task"] == "text2text-generation"
assert kwargs["model"] == "google/flan-t5-base"
@pytest.mark.unit
def test_constructor_with_torch_dtype(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with torch_dtype parameter using the actual torch object
"""
layer = HFLocalInvocationLayer("google/flan-t5-base", torch_dtype=torch.float16)
assert layer.pipe == mock_pipeline.return_value
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
assert kwargs["torch_dtype"] == torch.float16
@pytest.mark.unit
def test_constructor_with_torch_dtype_as_str(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with torch_dtype parameter using the string definition
"""
layer = HFLocalInvocationLayer("google/flan-t5-base", torch_dtype="torch.float16")
assert layer.pipe == mock_pipeline.return_value
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
assert kwargs["torch_dtype"] == torch.float16
@pytest.mark.unit
def test_constructor_with_torch_dtype_auto(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with torch_dtype parameter using the auto string definition
"""
layer = HFLocalInvocationLayer("google/flan-t5-base", torch_dtype="auto")
assert layer.pipe == mock_pipeline.return_value
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
assert kwargs["torch_dtype"] == "auto"
@pytest.mark.unit
def test_constructor_with_invalid_torch_dtype(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with invalid torch_dtype parameter
"""
# we need to provide torch_dtype as a string but with torch. prefix
# this should raise an error
with pytest.raises(ValueError, match="torch_dtype should be a torch.dtype, a string with 'torch.' prefix"):
HFLocalInvocationLayer("google/flan-t5-base", torch_dtype="float16")
@pytest.mark.unit
def test_constructor_with_invalid_torch_dtype_object(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with invalid parameter
"""
# we need to provide torch_dtype as a string but with torch. prefix
# this should raise an error
with pytest.raises(ValueError, match="Invalid torch_dtype value {'invalid': 'object'}"):
HFLocalInvocationLayer("google/flan-t5-base", torch_dtype={"invalid": "object"})
@pytest.mark.integration
def test_ensure_token_limit_positive():
"""
Test that ensure_token_limit works as expected, short prompt text is not changed
"""
prompt_text = "this is a short prompt"
layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=10, model_max_length=20)
processed_prompt_text = layer._ensure_token_limit(prompt_text)
assert prompt_text == processed_prompt_text
@pytest.mark.integration
def test_ensure_token_limit_negative(caplog):
"""
Test that ensure_token_limit chops the prompt text if it's longer than the max length allowed by the model
"""
prompt_text = "this is a prompt test that is longer than the max length allowed by the model"
layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=10, model_max_length=20)
processed_prompt_text = layer._ensure_token_limit(prompt_text)
assert prompt_text != processed_prompt_text
assert len(processed_prompt_text.split()) <= len(prompt_text.split())
expected_message = (
"The prompt has been truncated from 17 tokens to 10 tokens so that the prompt length and "
"answer length (10 tokens) fit within the max token limit (20 tokens). Shorten the prompt "
"to prevent it from being cut off"
)
assert caplog.records[0].message == expected_message
@pytest.mark.unit
def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
"""
Test that the constructor sets the pipeline with the pretrained model (if provided)
"""
# actual model and tokenizer passed to the pipeline
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
HFLocalInvocationLayer(
model_name_or_path="irrelevant_when_model_is_provided",
model=model,
tokenizer=tokenizer,
task_name="text2text-generation",
)
mock_pipeline.assert_called_once()
# mock_get_task is not called as we provided task_name parameter
mock_get_task.assert_not_called()
args, kwargs = mock_pipeline.call_args
# correct tokenizer and model are set as well
assert kwargs["tokenizer"] == tokenizer
assert kwargs["model"] == model
@pytest.mark.unit
def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with invalid kwargs
"""
HFLocalInvocationLayer("google/flan-t5-base", some_invalid_kwarg="invalid")
mock_pipeline.assert_called_once()
mock_get_task.assert_called_once()
args, kwargs = mock_pipeline.call_args
# invalid kwargs are ignored and not passed to the pipeline
assert "some_invalid_kwarg" not in kwargs
# still our 13 kwargs passed to the pipeline
assert len(kwargs) == 13
@pytest.mark.unit
def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
"""
Test HFLocalInvocationLayer init with various kwargs, make sure all of them are passed to the pipeline
except for the invalid ones
"""
HFLocalInvocationLayer(
"google/flan-t5-base",
task_name="text2text-generation",
tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
config=Mock(),
revision="1.1",
device="cpu",
device_map="auto",
first_invalid_kwarg="invalid",
second_invalid_kwarg="invalid",
)
mock_pipeline.assert_called_once()
# mock_get_task is not called as we provided task_name parameter
mock_get_task.assert_not_called()
args, kwargs = mock_pipeline.call_args
# invalid kwargs are ignored and not passed to the pipeline
assert "first_invalid_kwarg" not in kwargs
assert "second_invalid_kwarg" not in kwargs
# correct task and model are set as well
assert kwargs["task"] == "text2text-generation"
assert not kwargs["device"]
assert kwargs["device_map"] and kwargs["device_map"] == "auto"
assert kwargs["revision"] == "1.1"
# still on 13 kwargs passed to the pipeline
assert len(kwargs) == 13
@pytest.mark.unit
def test_text_generation_model():
# test simple prompting with text generation model
# by default, we force the model not return prompt text
# Thus text-generation models can be used with PromptNode
# just like text2text-generation models
layer = HFLocalInvocationLayer("bigscience/bigscience-small-testing")
r = layer.invoke(prompt="Hello big science!")
assert len(r[0]) > 0
# test prompting with parameter to return prompt text as well
# users can use this param to get the prompt text and the generated text
r = layer.invoke(prompt="Hello big science!", return_full_text=True)
assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
@pytest.mark.unit
def test_text_generation_model_via_custom_pretrained_model():
tokenizer = AutoTokenizer.from_pretrained("bigscience/bigscience-small-testing")
model = BloomForCausalLM.from_pretrained("bigscience/bigscience-small-testing")
layer = HFLocalInvocationLayer(
"irrelevant_when_model_is_provided", model=model, tokenizer=tokenizer, task_name="text-generation"
)
r = layer.invoke(prompt="Hello big science")
assert len(r[0]) > 0
# test prompting with parameter to return prompt text as well
# users can use this param to get the prompt text and the generated text
r = layer.invoke(prompt="Hello big science", return_full_text=True)
assert len(r[0]) > 0 and r[0].startswith("Hello big science")
@pytest.mark.unit
def test_streaming_stream_param_in_constructor():
"""
Test stream parameter is correctly passed to pipeline invocation via HF streamer parameter
"""
layer = HFLocalInvocationLayer(stream=True)
layer.pipe = MagicMock()
layer.invoke(prompt="Tell me hello")
args, kwargs = layer.pipe.call_args
assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)
@pytest.mark.unit
def test_streaming_stream_handler_param_in_constructor():
"""
Test stream parameter is correctly passed to pipeline invocation
"""
dtsh = DefaultTokenStreamingHandler()
layer = HFLocalInvocationLayer(stream_handler=dtsh)
layer.pipe = MagicMock()
layer.invoke(prompt="Tell me hello")
args, kwargs = layer.pipe.call_args
assert "streamer" in kwargs
hf_streamer = kwargs["streamer"]
# we wrap our TokenStreamingHandler with HFTokenStreamingHandler
assert isinstance(hf_streamer, HFTokenStreamingHandler)
# but under the hood, the wrapped handler is DefaultTokenStreamingHandler we passed
assert isinstance(hf_streamer.token_handler, DefaultTokenStreamingHandler)
assert hf_streamer.token_handler == dtsh
@pytest.mark.unit
def test_supports(tmp_path):
"""
Test that supports returns True correctly for HFLocalInvocationLayer
"""
# mock get_task to avoid remote calls to HF hub
mock_get_task = Mock(return_value="text2text-generation")
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
assert HFLocalInvocationLayer.supports("google/flan-t5-base")
assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
assert mock_get_task.call_count == 3
# some HF local model directory, let's use the one from test/prompt/invocation_layer
assert HFLocalInvocationLayer.supports(str(tmp_path))
# but not some non text2text-generation or non text-generation model
# i.e image classification model
mock_get_task = Mock(return_value="image-classification")
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
assert not HFLocalInvocationLayer.supports("nateraw/vit-age-classifier")
assert mock_get_task.call_count == 1
# or some POS tagging model
mock_get_task = Mock(return_value="pos-tagging")
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
assert not HFLocalInvocationLayer.supports("vblagoje/bert-english-uncased-finetuned-pos")
assert mock_get_task.call_count == 1
# unless we specify the task name to override the default
# short-circuit the get_task call
assert HFLocalInvocationLayer.supports(
"vblagoje/bert-english-uncased-finetuned-pos", task_name="text2text-generation"
)
@pytest.mark.unit
def test_stop_words_criteria_set():
"""
Test that stop words criteria is correctly set in pipeline invocation
"""
layer = HFLocalInvocationLayer(
model_name_or_path="hf-internal-testing/tiny-random-t5", task_name="text2text-generation"
)
layer.pipe = MagicMock()
layer.invoke(prompt="Tell me hello", stop_words=["hello", "world"])
args, kwargs = layer.pipe.call_args
assert "stopping_criteria" in kwargs
assert isinstance(kwargs["stopping_criteria"], StoppingCriteriaList)
assert len(kwargs["stopping_criteria"]) == 1
assert isinstance(kwargs["stopping_criteria"][0], StopWordsCriteria)
@pytest.mark.integration
@pytest.mark.parametrize("stop_words", [["good"], ["hello", "good"], ["hello", "good", "health"]])
def test_stop_words_single_token(stop_words):
"""
Test that stop words criteria is used and that it works with single token stop words
"""
# simple test with words not broken down into multiple tokens
default_model = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(default_model)
# each word is broken down into a single token
tokens = tokenizer.tokenize("good health wish")
assert len(tokens) == 3
layer = HFLocalInvocationLayer(model_name_or_path=default_model)
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=stop_words)
assert len(result) > 0
assert result[0].startswith("I wish you a")
assert "good" not in result[0]
assert "health" not in result[0]
@pytest.mark.integration
def test_stop_words_multiple_token():
"""
Test that stop words criteria is used and that it works for multi-token words
"""
# complex test with words broken down into multiple tokens
default_model = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(default_model)
# single word unambiguously is broken down into 3 tokens
tokens = tokenizer.tokenize("unambiguously")
assert len(tokens) == 3
layer = HFLocalInvocationLayer(model_name_or_path=default_model)
result = layer.invoke(
prompt="Generate a sentence `I wish you unambiguously good health`", stop_words=["unambiguously"]
)
# yet the stop word is correctly stopped on and removed
assert len(result) > 0
assert result[0].startswith("I wish you")
assert "unambiguously" not in result[0]
assert "good" not in result[0]
assert "health" not in result[0]
@pytest.mark.integration
def test_stop_words_not_being_found():
# simple test with words not broken down into multiple tokens
layer = HFLocalInvocationLayer()
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=["Berlin"])
assert len(result) > 0
for word in "I wish you a good health".split():
assert word in result[0]
@pytest.mark.unit
def test_generation_kwargs_from_constructor():
"""
Test that generation_kwargs are correctly passed to pipeline invocation from constructor
"""
the_question = "What does 42 mean?"
# test that generation_kwargs are passed to the underlying HF model
layer = HFLocalInvocationLayer(generation_kwargs={"do_sample": True})
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
layer.invoke(prompt=the_question)
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "max_length": 100}, {})
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
layer = HFLocalInvocationLayer(generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
layer.invoke(prompt=the_question)
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
@pytest.mark.unit
def test_generation_kwargs_from_invoke():
"""
Test that generation_kwargs passed to invoke are passed to the underlying HF model
"""
the_question = "What does 42 mean?"
# test that generation_kwargs are passed to the underlying HF model
layer = HFLocalInvocationLayer()
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
layer.invoke(prompt=the_question, generation_kwargs={"do_sample": True})
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "max_length": 100}, {})
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
layer = HFLocalInvocationLayer()
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
layer.invoke(prompt=the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})

View File

@ -36,28 +36,3 @@ def test_construtor_with_custom_model():
def test_constructor_with_no_supported_model():
with pytest.raises(ValueError, match="Model some-random-model is not supported"):
PromptModel("some-random-model")
def create_mock_pipeline(model_name_or_path=None, max_length=100):
return Mock(
**{"model_name_or_path": model_name_or_path},
return_value=Mock(**{"model_name_or_path": model_name_or_path, "tokenizer.model_max_length": max_length}),
)
@pytest.mark.unit
def test_hf_local_invocation_layer_with_task_name():
mock_pipeline = create_mock_pipeline()
mock_get_task = Mock(return_value="dummy_task")
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline", mock_pipeline):
PromptModel(
model_name_or_path="local_model",
max_length=100,
model_kwargs={"task_name": "dummy_task"},
invocation_layer_class=HFLocalInvocationLayer,
)
# checking if get_task is called when task_name is passed to HFLocalInvocationLayer constructor
mock_get_task.assert_not_called()
mock_pipeline.assert_called_once()

View File

@ -196,217 +196,6 @@ def test_invalid_template_params(mock_model, mock_prompthub):
node.prompt("question-answering-per-document", some_crazy_key="Berlin is the capital of Germany.")
@pytest.mark.integration
def test_generation_kwargs_from_prompt_node_init():
the_question = "What does 42 mean?"
# test that generation_kwargs are passed to the underlying HF model
node = PromptNode(model_kwargs={"generation_kwargs": {"do_sample": True}})
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
node(the_question)
mock_call.assert_called_with(
the_question, {}, {"do_sample": True, "num_return_sequences": 1, "num_beams": 1, "max_length": 100}, {}
)
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
node = PromptNode(model_kwargs={"generation_kwargs": GenerationConfig(do_sample=True, top_p=0.9)})
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
node(the_question)
mock_call.assert_called_with(
the_question,
{},
{"do_sample": True, "top_p": 0.9, "num_return_sequences": 1, "num_beams": 1, "max_length": 100},
{},
)
@pytest.mark.integration
def test_generation_kwargs_from_prompt_node_call():
the_question = "What does 42 mean?"
# default node with local HF model
node = PromptNode()
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
# test that generation_kwargs are passed to the underlying HF model
node(the_question, generation_kwargs={"do_sample": True})
mock_call.assert_called_with(
the_question, {}, {"do_sample": True, "num_return_sequences": 1, "num_beams": 1, "max_length": 100}, {}
)
# default node with local HF model
node = PromptNode()
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
node(the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
mock_call.assert_called_with(
the_question,
{},
{"do_sample": True, "top_p": 0.9, "num_return_sequences": 1, "num_beams": 1, "max_length": 100},
{},
)
@pytest.mark.unit
def test_generation_kwargs_are_passed_to_prompt_model_invoke():
prompt = "What does 42 mean?"
mock_prompt_model = Mock()
# This is set so this mock can pass isinstance() checks run in
# PromptNode.__init__
mock_prompt_model.__class__ = PromptModel
mock_prompt_model.invoke.return_value = []
mock_prompt_model._ensure_token_limit.return_value = prompt
# Create PromptNode using the mocked PromptModel
node = PromptNode(model_name_or_path=mock_prompt_model)
node.run(query=prompt, prompt_template="{query}", generation_kwargs={"do_sample": True})
# Verify the do_sample keyword argument is passed to PromptModel.invoke()
mock_prompt_model.invoke.assert_called_once_with(
prompt, stop_words=None, top_k=1, query="What does 42 mean?", do_sample=True
)
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_stop_words(prompt_model):
# TODO: This can be a unit test for StopWordCriteria
skip_test_for_invalid_key(prompt_model)
# test single stop word for both HF and OpenAI
# set stop words in PromptNode
node = PromptNode(prompt_model, stop_words=["capital"])
# with default prompt template and stop words set in PN
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
)
assert r[0] == "What is the" or r[0] == "What city is the"
# test stop words for both HF and OpenAI
# set stop words in PromptNode
node = PromptNode(prompt_model, stop_words=["capital", "Germany"])
# with default prompt template and stop words set in PN
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
)
assert r[0] == "What is the" or r[0] == "What city is the"
# with default prompt template and stop words set in kwargs (overrides PN stop words)
r = node.prompt(
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
documents=["Berlin is the capital of Germany."],
stop_words=None,
)
assert "capital" in r[0] or "Germany" in r[0]
# simple prompting
r = node("Given the context please generate a question. Context: Berlin is the capital of Germany.; Question:")
assert len(r[0]) > 0
assert "capital" not in r[0]
assert "Germany" not in r[0]
# simple prompting with stop words set in kwargs (overrides PN stop words)
r = node(
"Given the context please generate a question. Context: Berlin is the capital of Germany.; Question:",
stop_words=None,
)
assert "capital" in r[0] or "Germany" in r[0]
tt = PromptTemplate("Given the context please generate a question. Context: {documents}; Question:")
# with custom prompt template
r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
assert r[0] == "What is the" or r[0] == "What city is the"
# with custom prompt template and stop words set in kwargs (overrides PN stop words)
r = node.prompt(tt, documents=["Berlin is the capital of Germany."], stop_words=None)
assert "capital" in r[0] or "Germany" in r[0]
@pytest.mark.integration
def test_prompt_node_model_max_length(caplog):
prompt = "This is a prompt " * 5 # (26 tokens with t5 flan tokenizer)
# test that model_max_length is set to 1024
# test that model doesn't truncate the prompt if it is shorter than
# the model max length minus the length of the output
# no warning is raised
node = PromptNode(model_kwargs={"model_max_length": 1024})
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
with caplog.at_level(logging.WARNING):
node.prompt(prompt)
assert len(caplog.text) <= 0
# test that model_max_length is set to 10
# test that model truncates the prompt if it is longer than the max length (10 tokens)
# a warning is raised
node = PromptNode(model_kwargs={"model_max_length": 10})
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
with caplog.at_level(logging.WARNING):
node.prompt(prompt)
assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_call(mock_model):
"""
Verifies model is created using expected stream handler when calling PromptNode.
"""
mock_handler = Mock()
node = PromptNode()
node.prompt_model = mock_model
node("Irrelevant prompt", stream=True, stream_handler=mock_handler)
# Verify model has been constructed with expected model_kwargs
mock_model.invoke.assert_called_once()
assert mock_model.invoke.call_args_list[0].kwargs["stream_handler"] == mock_handler
@pytest.mark.unit
def test_prompt_node_hf_model_streaming():
# tests that HF streaming handler is passed to the HF pipeline run_single method as "streamer" kwarg with
# the required HF type transformers.generation.streamers.TextStreamer
pn = PromptNode(model_kwargs={"stream_handler": DefaultTokenStreamingHandler()})
with patch.object(pn.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
pn("Irrelevant prompt")
args, kwargs = mock_call.call_args
assert "streamer" in args[2]
assert isinstance(args[2]["streamer"], TextStreamer)
@pytest.mark.unit
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_prompt_node_streaming_handler_on_constructor(mock_model):
"""
Verifies model is created using expected stream handler when constructing PromptNode.
"""
model_kwargs = {"stream_handler": Mock()}
PromptNode(model_kwargs=model_kwargs)
# Verify model has been constructed with expected model_kwargs
mock_model.assert_called_once()
assert mock_model.call_args_list[0].kwargs["model_kwargs"] == model_kwargs
@pytest.mark.skip
@pytest.mark.integration
def test_prompt_node_with_text_generation_model():
# TODO: This is an integration test for HFLocalInvocationLayer
# test simple prompting with text generation model
# by default, we force the model not return prompt text
# Thus text-generation models can be used with PromptNode
# just like text2text-generation models
node = PromptNode("bigscience/bigscience-small-testing")
r = node("Hello big science!")
assert len(r[0]) > 0
# test prompting with parameter to return prompt text as well
# users can use this param to get the prompt text and the generated text
r = node("Hello big science!", return_full_text=True)
assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
@pytest.mark.skip
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
@ -1194,14 +983,6 @@ class TestRunBatch:
assert isinstance(result["results"][0][0], str)
@pytest.mark.skip
@pytest.mark.integration
def test_HFLocalInvocationLayer_supports():
# TODO: HFLocalInvocationLayer test, to be moved
assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
@pytest.mark.skip
@pytest.mark.integration
def test_chatgpt_direct_prompting(chatgpt_prompt_model):