mirror of https://github.com/deepset-ai/haystack.git
synced 2025-09-14 02:32:32 +00:00
feat: Update allowed models to be used with Prompt Node (#4018)
* Update allowed models to be used with Prompt Node
* Added try/except block around the config load to skip over OpenAI models
* Fixed tests
* Added warning message
* Added a test for different HF models that could be used in PromptNode
This commit is contained in:
parent 8135e75139
commit 01d39df863
haystack/nodes/prompt/prompt_node.py

@@ -11,12 +11,13 @@ import requests
 import torch
 from transformers import (
     pipeline,
     AutoModelForSeq2SeqLM,
+    AutoConfig,
     StoppingCriteria,
     StoppingCriteriaList,
-    PreTrainedTokenizerFast,
     PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 from haystack import MultiLabel
 from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
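Side note: MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES is the transformers registry the new check relies on. A quick illustrative probe (assumes only that transformers is installed; the "t5" entry shown is one of many):

from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

# An OrderedDict mapping model types to seq2seq architecture class names.
print(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES["t5"])  # -> "T5ForConditionalGeneration"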
@@ -265,8 +266,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about these kwargs, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
-
-
         """
         super().__init__(model_name_or_path, max_length)
         self.use_auth_token = use_auth_token
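For illustration, a minimal sketch of how these pipeline kwargs could be passed through; it assumes PromptModel forwards model_kwargs to the underlying Hugging Face pipeline, as this docstring describes:

from haystack.nodes.prompt import PromptModel

# Sketch: torch_dtype and device_map should end up in transformers.pipeline(...).
model = PromptModel(
    model_name_or_path="google/flan-t5-base",
    model_kwargs={"torch_dtype": "auto", "device_map": "auto"},
)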
@@ -365,15 +364,22 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
     @classmethod
     def supports(cls, model_name_or_path: str) -> bool:
-        if not all(m in model_name_or_path for m in ["google", "flan", "t5"]):
-            return False
-
-        try:
-            # if it is google flan t5, load it, we'll use it anyway and also check if model loads correctly
-            AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
-        except EnvironmentError:
-            return False
-        return True
+        try:
+            config = AutoConfig.from_pretrained(model_name_or_path)
+        except OSError:
+            # This is needed so OpenAI models are skipped over
+            return False
+
+        if not all(m in model_name_or_path for m in ["flan", "t5"]):
+            logger.warning(
+                "PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
+                "Many of the default prompts and PromptTemplates will likely not work as intended. "
+                "Please use custom prompts and PromptTemplates specific to the %s model",
+                model_name_or_path,
+            )
+
+        supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
+        return config.architectures[0] in supported_models
 
 
 class OpenAIInvocationLayer(PromptModelInvocationLayer):
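Conceptually, supports now asks the model's config what architecture it declares instead of pattern-matching the name. A minimal standalone sketch of the same idea (the helper name is illustrative, and it assumes network access to the Hugging Face Hub):

from transformers import AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

def is_supported_seq2seq(model_name_or_path: str) -> bool:
    try:
        config = AutoConfig.from_pretrained(model_name_or_path)
    except OSError:
        # Remote-only models such as "text-davinci-003" have no HF config to fetch.
        return False
    # Accept any model whose declared architecture is a seq2seq (text2text) class.
    return config.architectures[0] in set(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())

# is_supported_seq2seq("google/flan-t5-base")  -> True
# is_supported_seq2seq("text-davinci-003")     -> False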
@@ -578,8 +584,8 @@ class PromptModel(BaseComponent):
             )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no invocation layer found."
-            f"Currently supported models are: {self.invocation_layers}"
-            f"Register new invocation layer for {self.model_name_or_path} using the register method."
+            f" Currently supported models are: {self.invocation_layers}"
+            f" Register a new invocation layer for {self.model_name_or_path} using the register method."
         )
 
     def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
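As the reworded error suggests, the escape hatch for unsupported models is a custom invocation layer. A hedged sketch of the subclass side (the class name, model prefix, and invoke body are hypothetical; the exact abstract interface depends on your Haystack version):

from haystack.nodes.prompt.prompt_node import PromptModelInvocationLayer

class MyRestInvocationLayer(PromptModelInvocationLayer):  # hypothetical layer
    def invoke(self, *args, **kwargs):
        prompt = kwargs.get("prompt", "")
        # Call your own model endpoint here; return a list of generated strings.
        return [f"echo: {prompt}"]

    @classmethod
    def supports(cls, model_name_or_path: str) -> bool:
        # Claim only the models this layer can actually serve.
        return model_name_or_path.startswith("my-org/")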
@@ -694,7 +700,6 @@ class PromptNode(BaseComponent):
     LLM does not "follow" prompt instructions well. This is why we recommend using T5 flan or OpenAI InstructGPT models.
 
     For more details, see the PromptNode [documentation](https://docs.haystack.deepset.ai/docs/prompt_node).
-
     """
 
     outgoing_edges: int = 1
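For reference, the recommended happy path looks roughly like this (output shown is indicative only):

from haystack.nodes.prompt import PromptNode

# flan-t5 is instruction-tuned, so the default PromptTemplates behave as intended.
prompt_node = PromptNode(model_name_or_path="google/flan-t5-base")
print(prompt_node("What is the capital of Germany?"))  # e.g. ["berlin"]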
test/nodes/test_prompt_node.py

@@ -8,6 +8,7 @@ import torch
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
+from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer
 
 
 def is_openai_api_key_set(api_key: str):
@@ -109,16 +110,6 @@ def test_create_prompt_node():
     assert prompt_node.model_name_or_path == "text-davinci-003"
     assert prompt_node.prompt_model is not None
 
-    with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"):
-        # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("vblagoje/bart_lfqa")
-
-    with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"):
-        # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("valhalla/t5-base-e2e-qg")
-
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptNode("some-random-model")
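These removals follow directly from the new supports logic: vblagoje/bart_lfqa and valhalla/t5-base-e2e-qg declare seq2seq architectures, so they should no longer be rejected. A hedged sanity check, not part of the committed diff:

# Both previously rejected models should now pass the config-based check.
assert HFLocalInvocationLayer.supports("vblagoje/bart_lfqa")
assert HFLocalInvocationLayer.supports("valhalla/t5-base-e2e-qg")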
@@ -713,3 +704,8 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
     )
     pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
     assert pipeline is not None
+
+
+def test_HFLocalInvocationLayer_supports():
+    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
+    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
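Both asserted models pass because their configs declare T5-style seq2seq architectures that appear in the mapping. A counter-example consistent with the new logic (again a sketch, not part of the committed test):

# "gpt2" declares the causal-LM architecture "GPT2LMHeadModel", which is not in
# the seq2seq mapping, so the local HF layer should decline it.
assert not HFLocalInvocationLayer.supports("gpt2")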