diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index f70a347c6..0bc2eb2e4 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -11,12 +11,13 @@ import requests
 import torch
 from transformers import (
     pipeline,
-    AutoModelForSeq2SeqLM,
+    AutoConfig,
     StoppingCriteria,
     StoppingCriteriaList,
-    PreTrainedTokenizerFast,
     PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 from haystack import MultiLabel
 from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
@@ -265,8 +266,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about these kwargs, see Hugging Face
         [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
-
-
         """
         super().__init__(model_name_or_path, max_length)
         self.use_auth_token = use_auth_token
@@ -365,15 +364,22 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
 
     @classmethod
     def supports(cls, model_name_or_path: str) -> bool:
-        if not all(m in model_name_or_path for m in ["google", "flan", "t5"]):
+        try:
+            config = AutoConfig.from_pretrained(model_name_or_path)
+        except OSError:
+            # This is needed so OpenAI models are skipped over
             return False
 
-        try:
-            # if it is google flan t5, load it, we'll use it anyway and also check if model loads correctly
-            AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
-        except EnvironmentError:
-            return False
-        return True
+        if not all(m in model_name_or_path for m in ["flan", "t5"]):
+            logger.warning(
+                "PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
+                "Many of the default prompts and PromptTemplates will likely not work as intended. "
+                "Please use custom prompts and PromptTemplates specific to the %s model",
+                model_name_or_path,
+            )
+
+        supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
+        return config.architectures[0] in supported_models
 
 
 class OpenAIInvocationLayer(PromptModelInvocationLayer):
@@ -578,8 +584,8 @@ class PromptModel(BaseComponent):
             )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no invocation layer found."
-            f"Currently supported models are: {self.invocation_layers}"
-            f"Register new invocation layer for {self.model_name_or_path} using the register method."
+            f" Currently supported models are: {self.invocation_layers}"
+            f" Register a new invocation layer for {self.model_name_or_path} using the register method."
        )
 
     def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
@@ -694,7 +700,6 @@ class PromptNode(BaseComponent):
     LLM does not "follow" prompt instructions well. This is why we recommend using T5 flan or OpenAI InstructGPT models.
     For more details, see the PromptNode [documentation](https://docs.haystack.deepset.ai/docs/prompt_node).
 
-
     """
 
     outgoing_edges: int = 1
diff --git a/test/nodes/test_prompt_node.py b/test/nodes/test_prompt_node.py
index f081f2adb..03fa29e45 100644
--- a/test/nodes/test_prompt_node.py
+++ b/test/nodes/test_prompt_node.py
@@ -8,6 +8,7 @@ import torch
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
+from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer
 
 
 def is_openai_api_key_set(api_key: str):
@@ -109,16 +110,6 @@ def test_create_prompt_node():
     assert prompt_node.model_name_or_path == "text-davinci-003"
     assert prompt_node.prompt_model is not None
 
-    with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"):
-        # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("vblagoje/bart_lfqa")
-
-    with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"):
-        # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("valhalla/t5-base-e2e-qg")
-
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptNode("some-random-model")
 
@@ -713,3 +704,8 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
     )
     pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
     assert pipeline is not None
+
+
+def test_HFLocalInvocationLayer_supports():
+    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
+    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")