Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-09-14 02:32:32 +00:00
feat: Update allowed models to be used with Prompt Node (#4018)
* Update allowed models to be used with Prompt Node
* Added try except block around the config to skip over OpenAI models.
* Fixing tests
* Adding warning message
* Adding test for different HF models that could be used in prompt node
parent 8135e75139
commit 01d39df863
@@ -11,12 +11,13 @@ import requests
 import torch
 from transformers import (
     pipeline,
-    AutoModelForSeq2SeqLM,
+    AutoConfig,
     StoppingCriteria,
     StoppingCriteriaList,
-    PreTrainedTokenizerFast,
     PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

 from haystack import MultiLabel
 from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
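The newly imported MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES is the transformers-internal mapping from a config's model_type to the name of its seq2seq LM class; the rewritten supports() below checks model architectures against its values. A minimal sketch of what it contains (the printed entries are illustrative, not exhaustive):

    from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

    # An OrderedDict keyed by model_type, e.g. "t5" -> "T5ForConditionalGeneration",
    # "bart" -> "BartForConditionalGeneration"; supports() tests membership in the values.
    for model_type, class_name in list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.items())[:5]:
        print(f"{model_type} -> {class_name}")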
@@ -265,8 +266,6 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
         For more details about these kwargs, see
         Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
-
-
         """
         super().__init__(model_name_or_path, max_length)
         self.use_auth_token = use_auth_token
@@ -365,15 +364,22 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):

     @classmethod
     def supports(cls, model_name_or_path: str) -> bool:
-        if not all(m in model_name_or_path for m in ["google", "flan", "t5"]):
+        try:
+            config = AutoConfig.from_pretrained(model_name_or_path)
+        except OSError:
+            # This is needed so OpenAI models are skipped over
             return False

-        try:
-            # if it is google flan t5, load it, we'll use it anyway and also check if model loads correctly
-            AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
-        except EnvironmentError:
-            return False
-        return True
+        if not all(m in model_name_or_path for m in ["flan", "t5"]):
+            logger.warning(
+                "PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
+                "Many of the default prompts and PromptTemplates will likely not work as intended. "
+                "Please use custom prompts and PromptTemplates specific to the %s model",
+                model_name_or_path,
+            )
+
+        supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
+        return config.architectures[0] in supported_models


 class OpenAIInvocationLayer(PromptModelInvocationLayer):
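A quick usage sketch of the new check, assuming access to the Hugging Face Hub (the flan checkpoint name is illustrative; the OpenAI name comes from the tests below):

    from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer

    # A flan-t5 checkpoint resolves via AutoConfig and its architecture is in the mapping.
    assert HFLocalInvocationLayer.supports("google/flan-t5-base")
    # OpenAI model names have no Hub config, so AutoConfig raises OSError -> False.
    assert not HFLocalInvocationLayer.supports("text-davinci-003")

Note that any seq2seq architecture now passes; models whose names lack "flan" and "t5" merely trigger the warning rather than being rejected.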
@@ -578,8 +584,8 @@ class PromptModel(BaseComponent):
             )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no invocation layer found."
-            f"Currently supported models are: {self.invocation_layers}"
-            f"Register new invocation layer for {self.model_name_or_path} using the register method."
+            f" Currently supported models are: {self.invocation_layers}"
+            f" Register a new invocation layer for {self.model_name_or_path} using the register method."
         )

     def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
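The added leading spaces matter because adjacent string literals in Python are concatenated with no separator; a minimal illustration:

    # Implicit concatenation joins adjacent literals directly:
    message = "no invocation layer found." " Currently supported models are: [...]"
    # Without the leading space this would read "...found.Currently supported..."
    print(message)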
@@ -694,7 +700,6 @@ class PromptNode(BaseComponent):
     LLM does not "follow" prompt instructions well. This is why we recommend using T5 flan or OpenAI InstructGPT models.

     For more details, see the PromptNode [documentation](https://docs.haystack.deepset.ai/docs/prompt_node).
-
     """

     outgoing_edges: int = 1
@@ -8,6 +8,7 @@ import torch
 from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
+from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer


 def is_openai_api_key_set(api_key: str):
@@ -109,16 +110,6 @@ def test_create_prompt_node():
     assert prompt_node.model_name_or_path == "text-davinci-003"
     assert prompt_node.prompt_model is not None

-    with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"):
-        # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("vblagoje/bart_lfqa")
-
-    with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"):
-        # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting
-        # currently support only T5-Flan models
-        prompt_node = PromptNode("valhalla/t5-base-e2e-qg")
-
     with pytest.raises(ValueError, match="Model some-random-model is not supported"):
         PromptNode("some-random-model")

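These two expectations are dropped because both checkpoints are AutoModelForSeq2SeqLM models, which the architecture-based check above now accepts (with a warning logged, since neither name contains both "flan" and "t5"). A sketch of the new outcome, assuming the usual architectures for these checkpoints:

    from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer

    # vblagoje/bart_lfqa is assumed to report ["BartForConditionalGeneration"] and
    # valhalla/t5-base-e2e-qg ["T5ForConditionalGeneration"]; both class names are
    # in the seq2seq mapping, so supports() now returns True (plus a warning).
    assert HFLocalInvocationLayer.supports("vblagoje/bart_lfqa")
    assert HFLocalInvocationLayer.supports("valhalla/t5-base-e2e-qg")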
@@ -713,3 +704,8 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
     )
     pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
     assert pipeline is not None
+
+
+def test_HFLocalInvocationLayer_supports():
+    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
+    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
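The new test covers both branches: philschmid/flan-t5-base-samsum contains both "flan" and "t5", so it passes silently, while bigscience/T0_3B is a T5-architecture model whose name contains neither, so it passes with the warning. Assuming the test file is test/nodes/test_prompt_node.py (as its imports suggest), it can be run in isolation with:

    pytest -k test_HFLocalInvocationLayer_supports test/nodes/test_prompt_node.py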