feat: Update allowed models to be used with Prompt Node (#4018)

* Update allowed models to be used with Prompt Node

* Added a try/except block around the config load to skip over OpenAI models.

* Fixed tests

* Added a warning message

* Added a test for different HF models that can be used with PromptNode
Sebastian 2023-02-08 12:47:52 +01:00 committed by GitHub
parent 8135e75139
commit 01d39df863
2 changed files with 25 additions and 24 deletions

haystack/nodes/prompt/prompt_node.py

@@ -11,12 +11,13 @@ import requests
 import torch
 from transformers import (
     pipeline,
-    AutoModelForSeq2SeqLM,
+    AutoConfig,
     StoppingCriteria,
     StoppingCriteriaList,
-    PreTrainedTokenizerFast,
     PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 from haystack import MultiLabel
 from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
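
Note on the new import: MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES is the transformers registry mapping each model type to the name of its sequence-to-sequence LM class. A quick illustration of what it contains (not part of the diff):

    from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

    # OrderedDict of model type -> class name
    print(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES["t5"])    # T5ForConditionalGeneration
    print(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES["bart"])  # BartForConditionalGeneration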
@@ -265,8 +266,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about these kwargs, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
         """
         super().__init__(model_name_or_path, max_length)
         self.use_auth_token = use_auth_token
@@ -365,15 +364,22 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
     @classmethod
     def supports(cls, model_name_or_path: str) -> bool:
-        if not all(m in model_name_or_path for m in ["google", "flan", "t5"]):
-            return False
-        try:
-            # if it is google flan t5, load it, we'll use it anyway and also check if model loads correctly
-            AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
-        except EnvironmentError:
-            return False
-        return True
+        try:
+            config = AutoConfig.from_pretrained(model_name_or_path)
+        except OSError:
+            # This is needed so OpenAI models are skipped over
+            return False
+        if not all(m in model_name_or_path for m in ["flan", "t5"]):
+            logger.warning(
+                "PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
+                "Many of the default prompts and PromptTemplates will likely not work as intended. "
+                "Please use custom prompts and PromptTemplates specific to the %s model",
+                model_name_or_path,
+            )
+        supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
+        return config.architectures[0] in supported_models
 
 
 class OpenAIInvocationLayer(PromptModelInvocationLayer):
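
The reworked supports() in one standalone sketch (illustrative only; assumes network access to the Hugging Face Hub, and is_supported is a hypothetical stand-in for HFLocalInvocationLayer.supports):

    from transformers import AutoConfig
    from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

    def is_supported(model_name_or_path: str) -> bool:
        try:
            # OpenAI model names such as "text-davinci-003" are not Hub repos,
            # so AutoConfig raises OSError and the model is skipped here.
            config = AutoConfig.from_pretrained(model_name_or_path)
        except OSError:
            return False
        # config.architectures lists the class name(s) the checkpoint was saved with,
        # e.g. ["T5ForConditionalGeneration"] for flan-t5 checkpoints.
        return config.architectures[0] in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values()

    print(is_supported("google/flan-t5-small"))  # True: a registered seq2seq architecture
    print(is_supported("text-davinci-003"))      # False: handled by OpenAIInvocationLayer

The practical effect: any checkpoint with a registered seq2seq architecture now passes supports(), and models whose names lack "flan"/"t5" only trigger the warning instead of being rejected outright.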
@@ -578,8 +584,8 @@ class PromptModel(BaseComponent):
             )
             raise ValueError(
                 f"Model {self.model_name_or_path} is not supported - no invocation layer found."
-                f"Currently supported models are: {self.invocation_layers}"
-                f"Register new invocation layer for {self.model_name_or_path} using the register method."
+                f" Currently supported models are: {self.invocation_layers}"
+                f" Register a new invocation layer for {self.model_name_or_path} using the register method."
             )
 
     def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
@@ -694,7 +700,6 @@ class PromptNode(BaseComponent):
     LLM does not "follow" prompt instructions well. This is why we recommend using T5 flan or OpenAI InstructGPT models.
     For more details, see the PromptNode [documentation](https://docs.haystack.deepset.ai/docs/prompt_node).
     """
 
     outgoing_edges: int = 1

test/nodes/test_prompt_node.py

@@ -8,6 +8,7 @@ import torch
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
+from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer
 
 
 def is_openai_api_key_set(api_key: str):
@@ -109,16 +110,6 @@ def test_create_prompt_node():
     assert prompt_node.model_name_or_path == "text-davinci-003"
     assert prompt_node.prompt_model is not None
 
-    with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"):
-        # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("vblagoje/bart_lfqa")
-
-    with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"):
-        # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("valhalla/t5-base-e2e-qg")
-
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptNode("some-random-model")
@@ -713,3 +704,8 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
     )
     pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
     assert pipeline is not None
+
+
+def test_HFLocalInvocationLayer_supports():
+    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
+    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
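
Both asserted models are T5-architecture seq2seq checkpoints, so they pass the new architecture check; bigscience/T0_3B additionally exercises the warning branch, since its name contains neither "flan" nor "t5". A hypothetical follow-up test asserting that warning, not part of this commit (caplog is pytest's built-in log-capture fixture):

    def test_HFLocalInvocationLayer_supports_warns_for_non_flan_t5_names(caplog):
        # Supported via the architecture check, but the ["flan", "t5"] name check fails,
        # so the instruction-following warning should be logged.
        assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
        assert "not fine-tuned on instruction following" in caplog.text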