Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-24 21:48:52 +00:00)
feat: pass model parameters to HFLocalInvocationLayer via model_kwargs, enabling direct model usage (#4956)
* Simplify HFLocalInvocationLayer, move/add unit tests
* PR feedback
* Better pipeline invocation, add mocked tests
* Minor improvements
* Mock pipeline directly, unit test updates
* PR feedback, change pytest type to integration
* Mock supports unit test
* add full stop
* PR feedback, improve unit tests
* Add mock_get_task fixture
* Further improve unit tests
* Minor unit test improvement
* Add unit tests, increase coverage
* Add unit tests, increase test coverage
* Small optimization, improve _ensure_token_limit unit test

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
This commit is contained in: parent eca8f66ffa, commit e3b069620b
@@ -1,4 +1,4 @@
from typing import Optional, Union, List, Dict
from typing import Optional, Union, List, Dict, Any
import logging
import os

@@ -11,6 +11,7 @@ from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
GenerationConfig,
Pipeline,
)
from transformers.pipelines import get_task

@@ -50,7 +51,8 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
includes: task_name, trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
includes: "task", "model", "config", "tokenizer", "feature_extractor", "revision", "use_auth_token",
"device_map", "device", "torch_dtype", "trust_remote_code", "model_kwargs", and "pipeline_class".
For more details about pipeline kwargs in general, see
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).

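For readers of this diff, a minimal usage sketch of the kwargs listed above (the model name, cache directory, and model_kwargs values are illustrative, and access to the Hugging Face Hub is assumed):

    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

    # task_name short-circuits the get_task() lookup; model_kwargs is forwarded by the
    # transformers pipeline to the underlying from_pretrained call
    layer = HFLocalInvocationLayer(
        model_name_or_path="google/flan-t5-base",
        task_name="text2text-generation",
        torch_dtype="auto",
        model_kwargs={"cache_dir": "./hf-cache"},
    )
    print(layer.invoke(prompt="Translate to German: Hello world"))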
@ -72,81 +74,41 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
self.__class__.__name__,
|
||||
self.devices[0],
|
||||
)
|
||||
|
||||
# Due to reflective construction of all invocation layers we might receive some
|
||||
# unknown kwargs, so we need to take only the relevant.
|
||||
# For more details refer to Hugging Face pipeline documentation
|
||||
# Do not use `device_map` AND `device` at the same time as they will conflict
|
||||
model_input_kwargs = {
|
||||
key: kwargs[key]
|
||||
for key in [
|
||||
"model_kwargs",
|
||||
"trust_remote_code",
|
||||
"revision",
|
||||
"feature_extractor",
|
||||
"tokenizer",
|
||||
"config",
|
||||
"use_fast",
|
||||
"torch_dtype",
|
||||
"device_map",
|
||||
"generation_kwargs",
|
||||
"model_max_length",
|
||||
"stream",
|
||||
"stream_handler",
|
||||
]
|
||||
if key in kwargs
|
||||
}
|
||||
# flatten model_kwargs one level
|
||||
if "model_kwargs" in model_input_kwargs:
|
||||
mkwargs = model_input_kwargs.pop("model_kwargs")
|
||||
model_input_kwargs.update(mkwargs)
|
||||
if "device" not in kwargs:
|
||||
kwargs["device"] = self.devices[0]
|
||||
|
||||
# save stream settings and stream_handler for pipeline invocation
|
||||
self.stream_handler = model_input_kwargs.pop("stream_handler", None)
|
||||
self.stream = model_input_kwargs.pop("stream", False)
|
||||
self.stream_handler = kwargs.get("stream_handler", None)
|
||||
self.stream = kwargs.get("stream", False)
|
||||
|
||||
# save generation_kwargs for pipeline invocation
|
||||
self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
|
||||
model_max_length = model_input_kwargs.pop("model_max_length", None)
|
||||
|
||||
torch_dtype = model_input_kwargs.get("torch_dtype")
|
||||
if torch_dtype is not None:
|
||||
if isinstance(torch_dtype, str):
|
||||
if "torch." in torch_dtype:
|
||||
torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
|
||||
elif torch_dtype == "auto":
|
||||
torch_dtype_resolved = torch_dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
|
||||
)
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype_resolved = torch_dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
|
||||
model_input_kwargs["torch_dtype"] = torch_dtype_resolved
|
||||
|
||||
if len(model_input_kwargs) > 0:
|
||||
logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
|
||||
self.generation_kwargs = kwargs.get("generation_kwargs", {})
|
||||
|
||||
# If task_name is not provided, get the task name from the model name or path (uses HFApi)
|
||||
if "task_name" in kwargs:
|
||||
self.task_name = kwargs.get("task_name")
|
||||
else:
|
||||
self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
|
||||
|
||||
self.pipe = pipeline(
|
||||
task=self.task_name, # task_name is used to determine the pipeline type
|
||||
model=model_name_or_path,
|
||||
device=self.devices[0] if "device_map" not in model_input_kwargs else None,
|
||||
use_auth_token=self.use_auth_token,
|
||||
model_kwargs=model_input_kwargs,
|
||||
self.task_name = (
|
||||
kwargs.get("task_name")
|
||||
if "task_name" in kwargs
|
||||
else get_task(model_name_or_path, use_auth_token=use_auth_token)
|
||||
)
|
||||
# we check in supports class method if task_name is supported but here we check again as
|
||||
# we could have gotten the task_name from kwargs
|
||||
if self.task_name not in ["text2text-generation", "text-generation"]:
|
||||
raise ValueError(
|
||||
f"Task name {self.task_name} is not supported. "
|
||||
f"We only support text2text-generation and text-generation tasks."
|
||||
)
|
||||
pipeline_kwargs = self._prepare_pipeline_kwargs(
|
||||
task=self.task_name, model_name_or_path=model_name_or_path, use_auth_token=use_auth_token, **kwargs
|
||||
)
|
||||
# create the transformer pipeline
|
||||
self.pipe: Pipeline = pipeline(**pipeline_kwargs)
|
||||
|
||||
# This is how the default max_length is determined for Text2TextGenerationPipeline shown here
|
||||
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/pipelines/text2text_generation.html
|
||||
# max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
|
||||
self.max_length = max_length or self.pipe.model.config.max_length
|
||||
|
||||
model_max_length = kwargs.get("model_max_length", None)
|
||||
# we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
|
||||
# embeddings and can accept sequences of more than 512 tokens
|
||||
if model_max_length is not None:
|
||||
@ -160,6 +122,37 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
self.pipe.tokenizer.model_max_length,
|
||||
)
|
||||
|
||||
def _prepare_pipeline_kwargs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitizes and prepares the kwargs passed to the transformers pipeline function.
|
||||
For more details about pipeline kwargs in general, see Hugging Face
|
||||
[documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
|
||||
"""
|
||||
# as device and device_map are mutually exclusive, we set device to None if device_map is provided
|
||||
device_map = kwargs.get("device_map", None)
|
||||
device = kwargs.get("device") if device_map is None else None
|
||||
# prepare torch_dtype for pipeline invocation
|
||||
torch_dtype = self._extract_torch_dtype(**kwargs)
|
||||
# and the model (prefer model instance over model_name_or_path str identifier)
|
||||
model = kwargs.get("model") or kwargs.get("model_name_or_path")
|
||||
|
||||
pipeline_kwargs = {
|
||||
"task": kwargs.get("task", None),
|
||||
"model": model,
|
||||
"config": kwargs.get("config", None),
|
||||
"tokenizer": kwargs.get("tokenizer", None),
|
||||
"feature_extractor": kwargs.get("feature_extractor", None),
|
||||
"revision": kwargs.get("revision", None),
|
||||
"use_auth_token": kwargs.get("use_auth_token", None),
|
||||
"device_map": device_map,
|
||||
"device": device,
|
||||
"torch_dtype": torch_dtype,
|
||||
"trust_remote_code": kwargs.get("trust_remote_code", False),
|
||||
"model_kwargs": kwargs.get("model_kwargs", {}),
|
||||
"pipeline_class": kwargs.get("pipeline_class", None),
|
||||
}
|
||||
return pipeline_kwargs
|
||||
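As a rough illustration of what _prepare_pipeline_kwargs hands to transformers.pipeline, a sketch with hypothetical default values (exact behaviour depends on the installed transformers version):

    from transformers import pipeline

    pipeline_kwargs = {
        "task": "text2text-generation",
        "model": "google/flan-t5-base",
        "config": None,
        "tokenizer": None,
        "feature_extractor": None,
        "revision": None,
        "use_auth_token": None,
        "device_map": None,
        "device": "cpu",            # set to None whenever device_map is provided
        "torch_dtype": None,
        "trust_remote_code": False,
        "model_kwargs": {},         # passed through to from_pretrained
        "pipeline_class": None,
    }
    pipe = pipeline(**pipeline_kwargs)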
|
||||
def invoke(self, *args, **kwargs):
|
||||
"""
|
||||
It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
|
||||
@ -172,14 +165,15 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
stop_words = kwargs.pop("stop_words", None)
|
||||
top_k = kwargs.pop("top_k", None)
|
||||
# either stream is True (will use default handler) or stream_handler is provided for custom handler
|
||||
stream = kwargs.get("stream", self.stream) or kwargs.get("stream_handler", self.stream_handler) is not None
|
||||
stream = kwargs.get("stream", self.stream)
|
||||
stream_handler = kwargs.get("stream_handler", self.stream_handler)
|
||||
stream = stream or stream_handler is not None
|
||||
if kwargs and "prompt" in kwargs:
|
||||
prompt = kwargs.pop("prompt")
|
||||
|
||||
# Consider only Text2TextGenerationPipeline and TextGenerationPipeline relevant, ignore others
|
||||
# For more details refer to Hugging Face Text2TextGenerationPipeline and TextGenerationPipeline
|
||||
# documentation
|
||||
# TODO resolve these kwargs from the pipeline signature
|
||||
model_input_kwargs = {
|
||||
key: kwargs[key]
|
||||
for key in [
|
||||
@ -227,7 +221,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
model_input_kwargs["max_length"] = self.max_length
|
||||
|
||||
if stream:
|
||||
stream_handler: TokenStreamingHandler = kwargs.pop("stream_handler", DefaultTokenStreamingHandler())
|
||||
stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()
|
||||
model_input_kwargs["streamer"] = HFTokenStreamingHandler(self.pipe.tokenizer, stream_handler)
|
||||
|
||||
output = self.pipe(prompt, **model_input_kwargs)
|
||||
@ -248,7 +242,8 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
:param prompt: Prompt text to be sent to the generative model.
|
||||
"""
|
||||
model_max_length = self.pipe.tokenizer.model_max_length
|
||||
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
|
||||
tokenized_prompt = self.pipe.tokenizer.tokenize(prompt)
|
||||
n_prompt_tokens = len(tokenized_prompt)
|
||||
n_answer_tokens = self.max_length
|
||||
if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
|
||||
return prompt
|
||||
@ -263,20 +258,38 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
|
||||
model_max_length,
|
||||
)
|
||||
|
||||
tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
|
||||
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
|
||||
tokenized_payload[: model_max_length - n_answer_tokens]
|
||||
tokenized_prompt[: model_max_length - n_answer_tokens]
|
||||
)
|
||||
return decoded_string
|
||||
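The truncation above boils down to simple token arithmetic; a standalone sketch, with limits chosen to match the _ensure_token_limit unit test further down:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    prompt = "this is a prompt test that is longer than the max length allowed by the model"

    model_max_length = 20   # assumed tokenizer limit
    n_answer_tokens = 10    # max_length reserved for the generated answer

    tokens = tokenizer.tokenize(prompt)
    if len(tokens) + n_answer_tokens > model_max_length:
        # keep only as many prompt tokens as still leave room for the answer
        prompt = tokenizer.convert_tokens_to_string(tokens[: model_max_length - n_answer_tokens])
    print(prompt)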
|
||||
def _extract_torch_dtype(self, **kwargs) -> Optional[torch.dtype]:
|
||||
torch_dtype_resolved = None
|
||||
torch_dtype = kwargs.get("torch_dtype", None)
|
||||
if torch_dtype is not None:
|
||||
if isinstance(torch_dtype, str):
|
||||
if "torch." in torch_dtype:
|
||||
torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
|
||||
elif torch_dtype == "auto":
|
||||
torch_dtype_resolved = torch_dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
|
||||
)
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype_resolved = torch_dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
|
||||
return torch_dtype_resolved
|
||||
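The accepted torch_dtype inputs can be summarised with a simplified stand-in for _extract_torch_dtype (illustration only, not the exact implementation above, which also handles the missing-value case):

    import torch

    def resolve_torch_dtype(torch_dtype):
        # accepted: a torch.dtype instance, the string "auto", or a "torch."-prefixed string
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto":
            return torch_dtype
        if isinstance(torch_dtype, str) and torch_dtype.startswith("torch."):
            return getattr(torch, torch_dtype[len("torch."):])
        raise ValueError(f"Invalid torch_dtype value {torch_dtype}")

    assert resolve_torch_dtype(torch.bfloat16) is torch.bfloat16
    assert resolve_torch_dtype("torch.float16") is torch.float16
    assert resolve_torch_dtype("auto") == "auto"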
|
||||
@classmethod
|
||||
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
|
||||
task_name: Optional[str] = None
|
||||
task_name: Optional[str] = kwargs.get("task_name", None)
|
||||
if os.path.exists(model_name_or_path):
|
||||
return True
|
||||
|
||||
try:
|
||||
task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
|
||||
task_name = task_name or get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
|
||||
except RuntimeError:
|
||||
# This will fail for all non-HF models
|
||||
return False
|
||||
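In practice, supports() either resolves the task remotely via get_task() or is short-circuited with an explicit task_name, as the tests below also show (Hub access is assumed for the first call):

    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

    assert HFLocalInvocationLayer.supports("google/flan-t5-base")  # task fetched from the Hub
    assert HFLocalInvocationLayer.supports(
        "vblagoje/bert-english-uncased-finetuned-pos", task_name="text2text-generation"  # lookup skipped
    )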
@@ -300,4 +313,5 @@ class StopWordsCriteria(StoppingCriteria):
self.stop_words = tokenizer(stop_words, add_special_tokens=False, return_tensors="pt").to(device)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(torch.isin(input_ids[-1], self.stop_words["input_ids"]))
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
return any(all(stop_word) for stop_word in stop_result)
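The rewritten __call__ only stops generation once every token of some stop word has been generated; a small self-contained check of that torch.isin logic (the token ids are made up):

    import torch

    generated_ids = torch.tensor([21820, 55, 1])              # pretend model output so far
    stop_word_ids = torch.tensor([[21820, 55], [999, 1000]])  # two tokenized stop words

    hits = torch.isin(stop_word_ids, generated_ids)           # [[True, True], [False, False]]
    should_stop = any(bool(row.all()) for row in hits)
    print(should_stop)                                        # True: all tokens of the first stop word appeared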
test/prompt/invocation_layer/test_hugging_face.py (new file, 510 lines)
@@ -0,0 +1,510 @@
|
||||
from unittest.mock import MagicMock, patch, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import device
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BloomForCausalLM, StoppingCriteriaList, GenerationConfig
|
||||
|
||||
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler, DefaultTokenStreamingHandler
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face import StopWordsCriteria
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline():
|
||||
# mock transformers pipeline
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mocked_pipeline:
|
||||
mocked_pipeline.return_value = Mock(**{"model_name_or_path": None, "tokenizer.model_max_length": 100})
|
||||
yield mocked_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_task():
|
||||
# mock get_task function
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task") as mock_get_task:
|
||||
mock_get_task.return_value = "text2text-generation"
|
||||
yield mock_get_task
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_invalid_task_name(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with invalid task_name
|
||||
"""
|
||||
with pytest.raises(ValueError, match="Task name custom-text2text-generation is not supported"):
|
||||
HFLocalInvocationLayer("google/flan-t5-base", task_name="custom-text2text-generation")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with model_name_or_path only
|
||||
"""
|
||||
HFLocalInvocationLayer("google/flan-t5-base")
|
||||
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
|
||||
# device is set to cpu by default and device_map is empty
|
||||
assert kwargs["device"] == device("cpu")
|
||||
assert not kwargs["device_map"]
|
||||
|
||||
# correct task and model are set
|
||||
assert kwargs["task"] == "text2text-generation"
|
||||
assert kwargs["model"] == "google/flan-t5-base"
|
||||
|
||||
# no matter what kwargs we pass or don't pass, there are always 13 predefined kwargs passed to the pipeline
|
||||
assert len(kwargs) == 13
|
||||
|
||||
# and these kwargs are passed to the pipeline
|
||||
assert list(kwargs.keys()) == [
|
||||
"task",
|
||||
"model",
|
||||
"config",
|
||||
"tokenizer",
|
||||
"feature_extractor",
|
||||
"revision",
|
||||
"use_auth_token",
|
||||
"device_map",
|
||||
"device",
|
||||
"torch_dtype",
|
||||
"trust_remote_code",
|
||||
"model_kwargs",
|
||||
"pipeline_class",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_model_name_and_device_map(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with model_name_or_path and device_map
|
||||
"""
|
||||
|
||||
layer = HFLocalInvocationLayer("google/flan-t5-base", device="cpu", device_map="auto")
|
||||
|
||||
assert layer.pipe == mock_pipeline.return_value
|
||||
mock_pipeline.assert_called_once()
|
||||
mock_get_task.assert_called_once()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
|
||||
# device is NOT set; device_map is auto because device_map takes precedence over device
|
||||
assert not kwargs["device"]
|
||||
assert kwargs["device_map"] and kwargs["device_map"] == "auto"
|
||||
|
||||
# correct task and model are set as well
|
||||
assert kwargs["task"] == "text2text-generation"
|
||||
assert kwargs["model"] == "google/flan-t5-base"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_torch_dtype(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with torch_dtype parameter using the actual torch object
|
||||
"""
|
||||
|
||||
layer = HFLocalInvocationLayer("google/flan-t5-base", torch_dtype=torch.float16)
|
||||
|
||||
assert layer.pipe == mock_pipeline.return_value
|
||||
mock_pipeline.assert_called_once()
|
||||
mock_get_task.assert_called_once()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
assert kwargs["torch_dtype"] == torch.float16
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_torch_dtype_as_str(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with torch_dtype parameter using the string definition
|
||||
"""
|
||||
|
||||
layer = HFLocalInvocationLayer("google/flan-t5-base", torch_dtype="torch.float16")
|
||||
|
||||
assert layer.pipe == mock_pipeline.return_value
|
||||
mock_pipeline.assert_called_once()
|
||||
mock_get_task.assert_called_once()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
assert kwargs["torch_dtype"] == torch.float16
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_torch_dtype_auto(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with torch_dtype parameter using the auto string definition
|
||||
"""
|
||||
|
||||
layer = HFLocalInvocationLayer("google/flan-t5-base", torch_dtype="auto")
|
||||
|
||||
assert layer.pipe == mock_pipeline.return_value
|
||||
mock_pipeline.assert_called_once()
|
||||
mock_get_task.assert_called_once()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
assert kwargs["torch_dtype"] == "auto"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_invalid_torch_dtype(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with invalid torch_dtype parameter
|
||||
"""
|
||||
|
||||
# torch_dtype given as a string must carry the "torch." prefix (or be "auto");
# a bare "float16" should raise an error
|
||||
with pytest.raises(ValueError, match="torch_dtype should be a torch.dtype, a string with 'torch.' prefix"):
|
||||
HFLocalInvocationLayer("google/flan-t5-base", torch_dtype="float16")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_invalid_torch_dtype_object(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with invalid parameter
|
||||
"""
|
||||
|
||||
# torch_dtype must be a torch.dtype instance, a "torch."-prefixed string, or "auto";
# any other object should raise an error
with pytest.raises(ValueError, match="Invalid torch_dtype value {'invalid': 'object'}"):
HFLocalInvocationLayer("google/flan-t5-base", torch_dtype={"invalid": "object"})
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_ensure_token_limit_positive():
|
||||
"""
|
||||
Test that ensure_token_limit works as expected, short prompt text is not changed
|
||||
"""
|
||||
prompt_text = "this is a short prompt"
|
||||
layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=10, model_max_length=20)
|
||||
|
||||
processed_prompt_text = layer._ensure_token_limit(prompt_text)
|
||||
assert prompt_text == processed_prompt_text
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_ensure_token_limit_negative(caplog):
|
||||
"""
|
||||
Test that ensure_token_limit chops the prompt text if it's longer than the max length allowed by the model
|
||||
"""
|
||||
prompt_text = "this is a prompt test that is longer than the max length allowed by the model"
|
||||
layer = HFLocalInvocationLayer("google/flan-t5-base", max_length=10, model_max_length=20)
|
||||
|
||||
processed_prompt_text = layer._ensure_token_limit(prompt_text)
|
||||
assert prompt_text != processed_prompt_text
|
||||
assert len(processed_prompt_text.split()) <= len(prompt_text.split())
|
||||
expected_message = (
|
||||
"The prompt has been truncated from 17 tokens to 10 tokens so that the prompt length and "
|
||||
"answer length (10 tokens) fit within the max token limit (20 tokens). Shorten the prompt "
|
||||
"to prevent it from being cut off"
|
||||
)
|
||||
assert caplog.records[0].message == expected_message
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_custom_pretrained_model(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test that the constructor sets the pipeline with the pretrained model (if provided)
|
||||
"""
|
||||
# actual model and tokenizer passed to the pipeline
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
HFLocalInvocationLayer(
|
||||
model_name_or_path="irrelevant_when_model_is_provided",
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task_name="text2text-generation",
|
||||
)
|
||||
|
||||
mock_pipeline.assert_called_once()
|
||||
# mock_get_task is not called as we provided task_name parameter
|
||||
mock_get_task.assert_not_called()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
|
||||
# correct tokenizer and model are set as well
|
||||
assert kwargs["tokenizer"] == tokenizer
|
||||
assert kwargs["model"] == model
|
||||
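The test above exercises the direct-model-usage path from the PR title; an end-to-end sketch (flan-t5-small is just a stand-in checkpoint):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

    layer = HFLocalInvocationLayer(
        model_name_or_path="irrelevant_when_model_is_provided",
        model=model,
        tokenizer=tokenizer,
        task_name="text2text-generation",  # skips get_task(), which would fail for the placeholder name
    )
    print(layer.invoke(prompt="Translate to German: Hello"))

The same objects can presumably also be routed through PromptNode by putting "model", "tokenizer", and "task_name" into model_kwargs, mirroring how the PromptModel test later in this diff passes "task_name".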
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_invalid_kwargs(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with invalid kwargs
|
||||
"""
|
||||
|
||||
HFLocalInvocationLayer("google/flan-t5-base", some_invalid_kwarg="invalid")
|
||||
|
||||
mock_pipeline.assert_called_once()
|
||||
mock_get_task.assert_called_once()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
|
||||
# invalid kwargs are ignored and not passed to the pipeline
|
||||
assert "some_invalid_kwarg" not in kwargs
|
||||
|
||||
# still our 13 kwargs passed to the pipeline
|
||||
assert len(kwargs) == 13
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_various_kwargs(mock_pipeline, mock_get_task):
|
||||
"""
|
||||
Test HFLocalInvocationLayer init with various kwargs, make sure all of them are passed to the pipeline
|
||||
except for the invalid ones
|
||||
"""
|
||||
|
||||
HFLocalInvocationLayer(
|
||||
"google/flan-t5-base",
|
||||
task_name="text2text-generation",
|
||||
tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
|
||||
config=Mock(),
|
||||
revision="1.1",
|
||||
device="cpu",
|
||||
device_map="auto",
|
||||
first_invalid_kwarg="invalid",
|
||||
second_invalid_kwarg="invalid",
|
||||
)
|
||||
|
||||
mock_pipeline.assert_called_once()
|
||||
# mock_get_task is not called as we provided task_name parameter
|
||||
mock_get_task.assert_not_called()
|
||||
|
||||
args, kwargs = mock_pipeline.call_args
|
||||
|
||||
# invalid kwargs are ignored and not passed to the pipeline
|
||||
assert "first_invalid_kwarg" not in kwargs
|
||||
assert "second_invalid_kwarg" not in kwargs
|
||||
|
||||
# correct task and model are set as well
|
||||
assert kwargs["task"] == "text2text-generation"
|
||||
assert not kwargs["device"]
|
||||
assert kwargs["device_map"] and kwargs["device_map"] == "auto"
|
||||
assert kwargs["revision"] == "1.1"
|
||||
|
||||
# still on 13 kwargs passed to the pipeline
|
||||
assert len(kwargs) == 13
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_text_generation_model():
|
||||
# test simple prompting with text generation model
|
||||
# by default, we force the model not return prompt text
|
||||
# Thus text-generation models can be used with PromptNode
|
||||
# just like text2text-generation models
|
||||
layer = HFLocalInvocationLayer("bigscience/bigscience-small-testing")
|
||||
r = layer.invoke(prompt="Hello big science!")
|
||||
assert len(r[0]) > 0
|
||||
|
||||
# test prompting with parameter to return prompt text as well
|
||||
# users can use this param to get the prompt text and the generated text
|
||||
r = layer.invoke(prompt="Hello big science!", return_full_text=True)
|
||||
assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_text_generation_model_via_custom_pretrained_model():
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/bigscience-small-testing")
|
||||
model = BloomForCausalLM.from_pretrained("bigscience/bigscience-small-testing")
|
||||
layer = HFLocalInvocationLayer(
|
||||
"irrelevant_when_model_is_provided", model=model, tokenizer=tokenizer, task_name="text-generation"
|
||||
)
|
||||
r = layer.invoke(prompt="Hello big science")
|
||||
assert len(r[0]) > 0
|
||||
|
||||
# test prompting with parameter to return prompt text as well
|
||||
# users can use this param to get the prompt text and the generated text
|
||||
r = layer.invoke(prompt="Hello big science", return_full_text=True)
|
||||
assert len(r[0]) > 0 and r[0].startswith("Hello big science")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_streaming_stream_param_in_constructor():
|
||||
"""
|
||||
Test stream parameter is correctly passed to pipeline invocation via HF streamer parameter
|
||||
"""
|
||||
layer = HFLocalInvocationLayer(stream=True)
|
||||
layer.pipe = MagicMock()
|
||||
|
||||
layer.invoke(prompt="Tell me hello")
|
||||
|
||||
args, kwargs = layer.pipe.call_args
|
||||
assert "streamer" in kwargs and isinstance(kwargs["streamer"], HFTokenStreamingHandler)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_streaming_stream_handler_param_in_constructor():
|
||||
"""
|
||||
Test stream parameter is correctly passed to pipeline invocation
|
||||
"""
|
||||
dtsh = DefaultTokenStreamingHandler()
|
||||
layer = HFLocalInvocationLayer(stream_handler=dtsh)
|
||||
layer.pipe = MagicMock()
|
||||
|
||||
layer.invoke(prompt="Tell me hello")
|
||||
|
||||
args, kwargs = layer.pipe.call_args
|
||||
assert "streamer" in kwargs
|
||||
hf_streamer = kwargs["streamer"]
|
||||
|
||||
# we wrap our TokenStreamingHandler with HFTokenStreamingHandler
|
||||
assert isinstance(hf_streamer, HFTokenStreamingHandler)
|
||||
|
||||
# but under the hood, the wrapped handler is DefaultTokenStreamingHandler we passed
|
||||
assert isinstance(hf_streamer.token_handler, DefaultTokenStreamingHandler)
|
||||
assert hf_streamer.token_handler == dtsh
|
||||
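For reference, the constructor-level streaming options exercised above in plain use (that the default handler prints tokens as they arrive is an assumption of this sketch):

    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer
    from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler

    # stream=True falls back to DefaultTokenStreamingHandler; a custom TokenStreamingHandler
    # can be supplied via stream_handler instead
    layer = HFLocalInvocationLayer("google/flan-t5-base", stream_handler=DefaultTokenStreamingHandler())
    layer.invoke(prompt="Briefly, what is Haystack?")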
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supports(tmp_path):
|
||||
"""
|
||||
Test that supports returns True correctly for HFLocalInvocationLayer
|
||||
"""
|
||||
# mock get_task to avoid remote calls to HF hub
|
||||
mock_get_task = Mock(return_value="text2text-generation")
|
||||
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
|
||||
assert HFLocalInvocationLayer.supports("google/flan-t5-base")
|
||||
assert HFLocalInvocationLayer.supports("mosaicml/mpt-7b")
|
||||
assert HFLocalInvocationLayer.supports("CarperAI/stable-vicuna-13b-delta")
|
||||
assert mock_get_task.call_count == 3
|
||||
|
||||
# some HF local model directory, let's use the one from test/prompt/invocation_layer
|
||||
assert HFLocalInvocationLayer.supports(str(tmp_path))
|
||||
|
||||
# but not some non text2text-generation or non text-generation model
|
||||
# i.e image classification model
|
||||
mock_get_task = Mock(return_value="image-classification")
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
|
||||
assert not HFLocalInvocationLayer.supports("nateraw/vit-age-classifier")
|
||||
assert mock_get_task.call_count == 1
|
||||
|
||||
# or some POS tagging model
|
||||
mock_get_task = Mock(return_value="pos-tagging")
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
|
||||
assert not HFLocalInvocationLayer.supports("vblagoje/bert-english-uncased-finetuned-pos")
|
||||
assert mock_get_task.call_count == 1
|
||||
|
||||
# unless we specify the task name to override the default
|
||||
# short-circuit the get_task call
|
||||
assert HFLocalInvocationLayer.supports(
|
||||
"vblagoje/bert-english-uncased-finetuned-pos", task_name="text2text-generation"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_stop_words_criteria_set():
|
||||
"""
|
||||
Test that stop words criteria is correctly set in pipeline invocation
|
||||
"""
|
||||
layer = HFLocalInvocationLayer(
|
||||
model_name_or_path="hf-internal-testing/tiny-random-t5", task_name="text2text-generation"
|
||||
)
|
||||
layer.pipe = MagicMock()
|
||||
|
||||
layer.invoke(prompt="Tell me hello", stop_words=["hello", "world"])
|
||||
|
||||
args, kwargs = layer.pipe.call_args
|
||||
assert "stopping_criteria" in kwargs
|
||||
assert isinstance(kwargs["stopping_criteria"], StoppingCriteriaList)
|
||||
assert len(kwargs["stopping_criteria"]) == 1
|
||||
assert isinstance(kwargs["stopping_criteria"][0], StopWordsCriteria)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("stop_words", [["good"], ["hello", "good"], ["hello", "good", "health"]])
|
||||
def test_stop_words_single_token(stop_words):
|
||||
"""
|
||||
Test that stop words criteria is used and that it works with single token stop words
|
||||
"""
|
||||
|
||||
# simple test with words not broken down into multiple tokens
|
||||
default_model = "google/flan-t5-base"
|
||||
tokenizer = AutoTokenizer.from_pretrained(default_model)
|
||||
# each word is broken down into a single token
|
||||
tokens = tokenizer.tokenize("good health wish")
|
||||
assert len(tokens) == 3
|
||||
|
||||
layer = HFLocalInvocationLayer(model_name_or_path=default_model)
|
||||
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=stop_words)
|
||||
assert len(result) > 0
|
||||
assert result[0].startswith("I wish you a")
|
||||
assert "good" not in result[0]
|
||||
assert "health" not in result[0]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_stop_words_multiple_token():
|
||||
"""
|
||||
Test that stop words criteria is used and that it works for multi-token words
|
||||
"""
|
||||
# complex test with words broken down into multiple tokens
|
||||
default_model = "google/flan-t5-base"
|
||||
tokenizer = AutoTokenizer.from_pretrained(default_model)
|
||||
# single word unambiguously is broken down into 3 tokens
|
||||
tokens = tokenizer.tokenize("unambiguously")
|
||||
assert len(tokens) == 3
|
||||
|
||||
layer = HFLocalInvocationLayer(model_name_or_path=default_model)
|
||||
result = layer.invoke(
|
||||
prompt="Generate a sentence `I wish you unambiguously good health`", stop_words=["unambiguously"]
|
||||
)
|
||||
# yet the stop word is correctly stopped on and removed
|
||||
assert len(result) > 0
|
||||
assert result[0].startswith("I wish you")
|
||||
assert "unambiguously" not in result[0]
|
||||
assert "good" not in result[0]
|
||||
assert "health" not in result[0]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_stop_words_not_being_found():
|
||||
# simple test with words not broken down into multiple tokens
|
||||
layer = HFLocalInvocationLayer()
|
||||
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=["Berlin"])
|
||||
assert len(result) > 0
|
||||
for word in "I wish you a good health".split():
|
||||
assert word in result[0]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generation_kwargs_from_constructor():
|
||||
"""
|
||||
Test that generation_kwargs are correctly passed to pipeline invocation from constructor
|
||||
"""
|
||||
the_question = "What does 42 mean?"
|
||||
# test that generation_kwargs are passed to the underlying HF model
|
||||
layer = HFLocalInvocationLayer(generation_kwargs={"do_sample": True})
|
||||
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
layer.invoke(prompt=the_question)
|
||||
|
||||
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "max_length": 100}, {})
|
||||
|
||||
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
|
||||
layer = HFLocalInvocationLayer(generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
|
||||
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
layer.invoke(prompt=the_question)
|
||||
|
||||
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generation_kwargs_from_invoke():
|
||||
"""
|
||||
Test that generation_kwargs passed to invoke are passed to the underlying HF model
|
||||
"""
|
||||
the_question = "What does 42 mean?"
|
||||
# test that generation_kwargs are passed to the underlying HF model
|
||||
layer = HFLocalInvocationLayer()
|
||||
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
layer.invoke(prompt=the_question, generation_kwargs={"do_sample": True})
|
||||
|
||||
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "max_length": 100}, {})
|
||||
|
||||
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
|
||||
layer = HFLocalInvocationLayer()
|
||||
with patch.object(layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
layer.invoke(prompt=the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
|
||||
|
||||
mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
|
||||
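Outside the mocks, the two generation_kwargs styles tested above look roughly like this in direct use:

    from transformers import GenerationConfig
    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

    layer = HFLocalInvocationLayer("google/flan-t5-base")

    # plain dict, per invocation
    layer.invoke(prompt="What does 42 mean?", generation_kwargs={"do_sample": True, "top_p": 0.9})

    # or a transformers GenerationConfig, fixed at construction time
    layer = HFLocalInvocationLayer("google/flan-t5-base", generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
    layer.invoke(prompt="What does 42 mean?")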
@ -36,28 +36,3 @@ def test_construtor_with_custom_model():
|
||||
def test_constructor_with_no_supported_model():
|
||||
with pytest.raises(ValueError, match="Model some-random-model is not supported"):
|
||||
PromptModel("some-random-model")
|
||||
|
||||
|
||||
def create_mock_pipeline(model_name_or_path=None, max_length=100):
|
||||
return Mock(
|
||||
**{"model_name_or_path": model_name_or_path},
|
||||
return_value=Mock(**{"model_name_or_path": model_name_or_path, "tokenizer.model_max_length": max_length}),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_hf_local_invocation_layer_with_task_name():
|
||||
mock_pipeline = create_mock_pipeline()
|
||||
mock_get_task = Mock(return_value="dummy_task")
|
||||
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task", mock_get_task):
|
||||
with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline", mock_pipeline):
|
||||
PromptModel(
|
||||
model_name_or_path="local_model",
|
||||
max_length=100,
|
||||
model_kwargs={"task_name": "dummy_task"},
|
||||
invocation_layer_class=HFLocalInvocationLayer,
|
||||
)
|
||||
# checking if get_task is called when task_name is passed to HFLocalInvocationLayer constructor
|
||||
mock_get_task.assert_not_called()
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
@ -196,217 +196,6 @@ def test_invalid_template_params(mock_model, mock_prompthub):
|
||||
node.prompt("question-answering-per-document", some_crazy_key="Berlin is the capital of Germany.")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_generation_kwargs_from_prompt_node_init():
|
||||
the_question = "What does 42 mean?"
|
||||
# test that generation_kwargs are passed to the underlying HF model
|
||||
node = PromptNode(model_kwargs={"generation_kwargs": {"do_sample": True}})
|
||||
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
node(the_question)
|
||||
mock_call.assert_called_with(
|
||||
the_question, {}, {"do_sample": True, "num_return_sequences": 1, "num_beams": 1, "max_length": 100}, {}
|
||||
)
|
||||
|
||||
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
|
||||
node = PromptNode(model_kwargs={"generation_kwargs": GenerationConfig(do_sample=True, top_p=0.9)})
|
||||
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
node(the_question)
|
||||
mock_call.assert_called_with(
|
||||
the_question,
|
||||
{},
|
||||
{"do_sample": True, "top_p": 0.9, "num_return_sequences": 1, "num_beams": 1, "max_length": 100},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_generation_kwargs_from_prompt_node_call():
|
||||
the_question = "What does 42 mean?"
|
||||
# default node with local HF model
|
||||
node = PromptNode()
|
||||
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
# test that generation_kwargs are passed to the underlying HF model
|
||||
node(the_question, generation_kwargs={"do_sample": True})
|
||||
mock_call.assert_called_with(
|
||||
the_question, {}, {"do_sample": True, "num_return_sequences": 1, "num_beams": 1, "max_length": 100}, {}
|
||||
)
|
||||
|
||||
# default node with local HF model
|
||||
node = PromptNode()
|
||||
with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
# test that generation_kwargs in the form of GenerationConfig are passed to the underlying HF model
|
||||
node(the_question, generation_kwargs=GenerationConfig(do_sample=True, top_p=0.9))
|
||||
mock_call.assert_called_with(
|
||||
the_question,
|
||||
{},
|
||||
{"do_sample": True, "top_p": 0.9, "num_return_sequences": 1, "num_beams": 1, "max_length": 100},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generation_kwargs_are_passed_to_prompt_model_invoke():
|
||||
prompt = "What does 42 mean?"
|
||||
|
||||
mock_prompt_model = Mock()
|
||||
# This is set so this mock can pass isinstance() checks run in
|
||||
# PromptNode.__init__
|
||||
mock_prompt_model.__class__ = PromptModel
|
||||
mock_prompt_model.invoke.return_value = []
|
||||
mock_prompt_model._ensure_token_limit.return_value = prompt
|
||||
|
||||
# Create PromptNode using the mocked PromptModel
|
||||
node = PromptNode(model_name_or_path=mock_prompt_model)
|
||||
node.run(query=prompt, prompt_template="{query}", generation_kwargs={"do_sample": True})
|
||||
|
||||
# Verify the do_sample keyword argument is passed to PromptModel.invoke()
|
||||
mock_prompt_model.invoke.assert_called_once_with(
|
||||
prompt, stop_words=None, top_k=1, query="What does 42 mean?", do_sample=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
|
||||
def test_stop_words(prompt_model):
|
||||
# TODO: This can be a unit test for StopWordCriteria
|
||||
skip_test_for_invalid_key(prompt_model)
|
||||
|
||||
# test single stop word for both HF and OpenAI
|
||||
# set stop words in PromptNode
|
||||
node = PromptNode(prompt_model, stop_words=["capital"])
|
||||
|
||||
# with default prompt template and stop words set in PN
|
||||
r = node.prompt(
|
||||
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
|
||||
documents=["Berlin is the capital of Germany."],
|
||||
)
|
||||
assert r[0] == "What is the" or r[0] == "What city is the"
|
||||
|
||||
# test stop words for both HF and OpenAI
|
||||
# set stop words in PromptNode
|
||||
node = PromptNode(prompt_model, stop_words=["capital", "Germany"])
|
||||
|
||||
# with default prompt template and stop words set in PN
|
||||
r = node.prompt(
|
||||
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
|
||||
documents=["Berlin is the capital of Germany."],
|
||||
)
|
||||
assert r[0] == "What is the" or r[0] == "What city is the"
|
||||
|
||||
# with default prompt template and stop words set in kwargs (overrides PN stop words)
|
||||
r = node.prompt(
|
||||
"Given the context please generate a question.\nContext: {documents};\nQuestion:",
|
||||
documents=["Berlin is the capital of Germany."],
|
||||
stop_words=None,
|
||||
)
|
||||
assert "capital" in r[0] or "Germany" in r[0]
|
||||
|
||||
# simple prompting
|
||||
r = node("Given the context please generate a question. Context: Berlin is the capital of Germany.; Question:")
|
||||
assert len(r[0]) > 0
|
||||
assert "capital" not in r[0]
|
||||
assert "Germany" not in r[0]
|
||||
|
||||
# simple prompting with stop words set in kwargs (overrides PN stop words)
|
||||
r = node(
|
||||
"Given the context please generate a question. Context: Berlin is the capital of Germany.; Question:",
|
||||
stop_words=None,
|
||||
)
|
||||
assert "capital" in r[0] or "Germany" in r[0]
|
||||
|
||||
tt = PromptTemplate("Given the context please generate a question. Context: {documents}; Question:")
|
||||
# with custom prompt template
|
||||
r = node.prompt(tt, documents=["Berlin is the capital of Germany."])
|
||||
assert r[0] == "What is the" or r[0] == "What city is the"
|
||||
|
||||
# with custom prompt template and stop words set in kwargs (overrides PN stop words)
|
||||
r = node.prompt(tt, documents=["Berlin is the capital of Germany."], stop_words=None)
|
||||
assert "capital" in r[0] or "Germany" in r[0]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_prompt_node_model_max_length(caplog):
|
||||
prompt = "This is a prompt " * 5 # (26 tokens with t5 flan tokenizer)
|
||||
|
||||
# test that model_max_length is set to 1024
|
||||
# test that model doesn't truncate the prompt if it is shorter than
|
||||
# the model max length minus the length of the output
|
||||
# no warning is raised
|
||||
node = PromptNode(model_kwargs={"model_max_length": 1024})
|
||||
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024
|
||||
with caplog.at_level(logging.WARNING):
|
||||
node.prompt(prompt)
|
||||
assert len(caplog.text) <= 0
|
||||
|
||||
# test that model_max_length is set to 10
|
||||
# test that model truncates the prompt if it is longer than the max length (10 tokens)
|
||||
# a warning is raised
|
||||
node = PromptNode(model_kwargs={"model_max_length": 10})
|
||||
assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 10
|
||||
with caplog.at_level(logging.WARNING):
|
||||
node.prompt(prompt)
|
||||
assert "The prompt has been truncated from 26 tokens to 0 tokens" in caplog.text
|
||||
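The override the test relies on, in isolation (that the flan-t5 tokenizer defaults to 512 tokens is an assumption of this sketch):

    from haystack.nodes import PromptNode

    # T5-style models use relative positional embeddings, so the tokenizer's default
    # model_max_length can be raised to accept longer prompts
    node = PromptNode("google/flan-t5-base", model_kwargs={"model_max_length": 1024})
    assert node.prompt_model.model_invocation_layer.pipe.tokenizer.model_max_length == 1024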
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
|
||||
def test_prompt_node_streaming_handler_on_call(mock_model):
|
||||
"""
|
||||
Verifies model is created using expected stream handler when calling PromptNode.
|
||||
"""
|
||||
mock_handler = Mock()
|
||||
node = PromptNode()
|
||||
node.prompt_model = mock_model
|
||||
node("Irrelevant prompt", stream=True, stream_handler=mock_handler)
|
||||
# Verify model has been constructed with expected model_kwargs
|
||||
mock_model.invoke.assert_called_once()
|
||||
assert mock_model.invoke.call_args_list[0].kwargs["stream_handler"] == mock_handler
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_prompt_node_hf_model_streaming():
|
||||
# tests that HF streaming handler is passed to the HF pipeline run_single method as "streamer" kwarg with
|
||||
# the required HF type transformers.generation.streamers.TextStreamer
|
||||
|
||||
pn = PromptNode(model_kwargs={"stream_handler": DefaultTokenStreamingHandler()})
|
||||
with patch.object(pn.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
|
||||
pn("Irrelevant prompt")
|
||||
args, kwargs = mock_call.call_args
|
||||
assert "streamer" in args[2]
|
||||
assert isinstance(args[2]["streamer"], TextStreamer)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
|
||||
def test_prompt_node_streaming_handler_on_constructor(mock_model):
|
||||
"""
|
||||
Verifies model is created using expected stream handler when constructing PromptNode.
|
||||
"""
|
||||
model_kwargs = {"stream_handler": Mock()}
|
||||
PromptNode(model_kwargs=model_kwargs)
|
||||
# Verify model has been constructed with expected model_kwargs
|
||||
mock_model.assert_called_once()
|
||||
assert mock_model.call_args_list[0].kwargs["model_kwargs"] == model_kwargs
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
def test_prompt_node_with_text_generation_model():
|
||||
# TODO: This is an integration test for HFLocalInvocationLayer
|
||||
# test simple prompting with text generation model
|
||||
# by default, we force the model not return prompt text
|
||||
# Thus text-generation models can be used with PromptNode
|
||||
# just like text2text-generation models
|
||||
node = PromptNode("bigscience/bigscience-small-testing")
|
||||
r = node("Hello big science!")
|
||||
assert len(r[0]) > 0
|
||||
|
||||
# test prompting with parameter to return prompt text as well
|
||||
# users can use this param to get the prompt text and the generated text
|
||||
r = node("Hello big science!", return_full_text=True)
|
||||
assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
|
||||
@ -1194,14 +983,6 @@ class TestRunBatch:
|
||||
assert isinstance(result["results"][0][0], str)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
def test_HFLocalInvocationLayer_supports():
|
||||
# TODO: HFLocalInvocationLayer test, to be moved
|
||||
assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
|
||||
assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
def test_chatgpt_direct_prompting(chatgpt_prompt_model):
|
||||
|
||||