feat: Update allowed models to be used with Prompt Node (#4018)

* Update allowed models to be used with PromptNode

* Added a try/except block around the config load so that OpenAI models are skipped over (see the sketch below)

* Fixed tests

* Added a warning message for models that are not fine-tuned on instruction-following tasks

* Added a test for different HF models that can be used with PromptNode
Sebastian 2023-02-08 12:47:52 +01:00 committed by GitHub
parent 8135e75139
commit 01d39df863
2 changed files with 25 additions and 24 deletions
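
A minimal sketch of the skip-over behavior from the second bullet (not part of the diff; assumes `transformers` is installed, model name as used in the tests below):

```python
# Why a try/except around the config load skips OpenAI models: OpenAI
# identifiers such as "text-davinci-003" are not Hugging Face Hub repos,
# so AutoConfig.from_pretrained raises an OSError that supports() catches.
from transformers import AutoConfig

try:
    AutoConfig.from_pretrained("text-davinci-003")
except OSError:
    print("No HF config found - this name is left to the OpenAI invocation layer")
```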

haystack/nodes/prompt/prompt_node.py

@@ -11,12 +11,13 @@
 import requests
 import torch
 from transformers import (
     pipeline,
-    AutoModelForSeq2SeqLM,
+    AutoConfig,
     StoppingCriteria,
     StoppingCriteriaList,
-    PreTrainedTokenizerFast,
     PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 from haystack import MultiLabel
 from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
@@ -265,8 +266,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about these kwargs, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
         """
         super().__init__(model_name_or_path, max_length)
         self.use_auth_token = use_auth_token
@@ -365,15 +364,22 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
     @classmethod
     def supports(cls, model_name_or_path: str) -> bool:
-        if not all(m in model_name_or_path for m in ["google", "flan", "t5"]):
-            return False
+        try:
+            config = AutoConfig.from_pretrained(model_name_or_path)
+        except OSError:
+            # This is needed so OpenAI models are skipped over
+            return False
-        try:
-            # if it is google flan t5, load it, we'll use it anyway and also check if model loads correctly
-            AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
-        except EnvironmentError:
-            return False
-        return True
+        if not all(m in model_name_or_path for m in ["flan", "t5"]):
+            logger.warning(
+                "PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
+                "Many of the default prompts and PromptTemplates will likely not work as intended. "
+                "Please use custom prompts and PromptTemplates specific to the %s model",
+                model_name_or_path,
+            )
+        supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
+        return config.architectures[0] in supported_models


 class OpenAIInvocationLayer(PromptModelInvocationLayer):
@@ -578,8 +584,8 @@ class PromptModel(BaseComponent):
             )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no invocation layer found."
-            f"Currently supported models are: {self.invocation_layers}"
-            f"Register new invocation layer for {self.model_name_or_path} using the register method."
+            f" Currently supported models are: {self.invocation_layers}"
+            f" Register a new invocation layer for {self.model_name_or_path} using the register method."
         )

     def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
@@ -694,7 +700,6 @@ class PromptNode(BaseComponent):
     LLM does not "follow" prompt instructions well. This is why we recommend using T5 flan or OpenAI InstructGPT models.
     For more details, see the PromptNode [documentation](https://docs.haystack.deepset.ai/docs/prompt_node).
     """
     outgoing_edges: int = 1
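
The hard-coded check for "google"/"flan"/"t5" name fragments above is replaced by a lookup against transformers' seq2seq registry. A hedged sketch of that check in isolation (model name illustrative, not part of the diff):

```python
# Standalone sketch of the new architecture check: a model is accepted
# when its configured architecture is one of the registered seq2seq LM
# classes, e.g. "T5ForConditionalGeneration".
from transformers import AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

config = AutoConfig.from_pretrained("google/flan-t5-small")
supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
print(config.architectures[0] in supported_models)  # True for "T5ForConditionalGeneration"
```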

test/nodes/test_prompt_node.py

@@ -8,6 +8,7 @@
 import torch
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
+from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer


 def is_openai_api_key_set(api_key: str):
@@ -109,16 +110,6 @@ def test_create_prompt_node():
     assert prompt_node.model_name_or_path == "text-davinci-003"
     assert prompt_node.prompt_model is not None

-    with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"):
-        # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("vblagoje/bart_lfqa")
-
-    with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"):
-        # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("valhalla/t5-base-e2e-qg")
-
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptNode("some-random-model")
@@ -713,3 +704,8 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
     )
     pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
     assert pipeline is not None
+
+
+def test_HFLocalInvocationLayer_supports():
+    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
+    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
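
The removed assertions above suggest a natural follow-on check, sketched here under the assumption that the listed architectures are registered seq2seq classes (it is not part of the commit): models such as vblagoje/bart_lfqa, previously rejected by the name check, should now pass supports(), while OpenAI names still return False via the OSError path.

```python
# Hedged follow-on, not part of the diff: the relaxed gate should accept
# non-flan seq2seq models (with a warning at PromptNode init) and still
# reject names that have no Hugging Face config.
from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer

assert HFLocalInvocationLayer.supports("vblagoje/bart_lfqa")    # BartForConditionalGeneration
assert not HFLocalInvocationLayer.supports("text-davinci-003")  # no HF config -> OSError -> False
```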