refactor: Separate PromptModelInvocationLayers in providers.py (#4327)

* Refactor PromptNode, separate PromptModelInvocationLayers in providers.py
Vladimir Blagojevic 2023-03-06 16:34:59 +01:00 committed by GitHub
parent 1548c5ba0f
commit 348e7d2dfe
5 changed files with 544 additions and 482 deletions

File: haystack/nodes/__init__.py

@@ -25,7 +25,7 @@ from haystack.nodes.image_to_text import TransformersImageToText
from haystack.nodes.label_generator import PseudoLabelGenerator
from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger, Shaper
from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
-from haystack.nodes.prompt import PromptNode, PromptTemplate, PromptModel
+from haystack.nodes.prompt import PromptNode, PromptTemplate, PromptModel, PromptModelInvocationLayer
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker

File: haystack/nodes/prompt/__init__.py

@@ -1 +1,2 @@
from haystack.nodes.prompt.prompt_node import PromptNode, PromptTemplate, PromptModel
+from haystack.nodes.prompt.providers import PromptModelInvocationLayer
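The re-export above keeps the public import path stable. A minimal sketch of the now-equivalent imports, assuming this commit is installed:

```python
# Both paths resolve to the same class after this change: providers.py is the
# new home, and haystack.nodes.prompt re-exports it for convenience.
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.providers import PromptModelInvocationLayer as ProviderBase

assert PromptModelInvocationLayer is ProviderBase
```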

File: haystack/nodes/prompt/prompt_node.py

@@ -1,34 +1,17 @@
import copy
import logging
+import pydoc
import re
-from abc import ABC, abstractmethod
+from abc import ABC
from string import Template
from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator

import torch
-from transformers import (
-    pipeline,
-    AutoConfig,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-)
-from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

from haystack import MultiLabel
-from haystack.errors import OpenAIError
-from haystack.modeling.utils import initialize_device_settings
from haystack.nodes.base import BaseComponent
+from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers
from haystack.schema import Document
-from haystack.utils.openai_utils import (
-    USE_TIKTOKEN,
-    openai_request,
-    _openai_text_completion_tokenization_details,
-    load_openai_tokenizer,
-    _check_openai_text_completion_answers,
-    count_openai_tokens,
-)
from haystack.telemetry_2 import send_event

logger = logging.getLogger(__name__)
@@ -178,456 +161,6 @@ class PromptTemplate(BasePromptTemplate, ABC):
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"
class PromptModelInvocationLayer:
"""
PromptModelInvocationLayer implementations execute a prompt on an underlying model.
The implementation can be a simple invocation of the underlying model running in a local runtime, or
it can even be remote, for example, a call to a remote API endpoint.
"""
def __init__(self, model_name_or_path: str, **kwargs):
"""
Creates a new PromptModelInvocationLayer instance.
:param model_name_or_path: The name or path of the underlying model.
:param kwargs: Additional keyword arguments passed to the underlying model.
"""
if model_name_or_path is None or len(model_name_or_path) == 0:
raise ValueError("model_name_or_path cannot be None or empty string")
self.model_name_or_path = model_name_or_path
@abstractmethod
def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated text using the underlying model.
:return: A list of generated text.
"""
pass
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
"""
Checks if the given model is supported by this invocation layer.
:param model_name_or_path: The name or path of the model.
:param kwargs: additional keyword arguments passed to the underlying model which might be used to determine
if the model is supported.
:return: True if this invocation layer supports the model, False otherwise.
"""
return False
@abstractmethod
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that length of the prompt and answer is within the maximum token length of the PromptModel.
:param prompt: Prompt text to be sent to the generative model.
"""
pass
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
"""
def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stop_words: List[str]):
super().__init__()
self.stop_words = tokenizer.encode(stop_words, add_special_tokens=False, return_tensors="pt")
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(torch.isin(input_ids[-1], self.stop_words[-1]))
class HFLocalInvocationLayer(PromptModelInvocationLayer):
"""
A subclass of the PromptModelInvocationLayer class. It loads a pre-trained model from Hugging Face and
passes a prepared prompt into that model.
Note: kwargs other than init parameter names are ignored to enable reflective construction of the class,
as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
"""
def __init__(
self,
model_name_or_path: str = "google/flan-t5-base",
max_length: Optional[int] = 100,
use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = True,
devices: Optional[List[Union[str, torch.device]]] = None,
**kwargs,
):
"""
Creates an instance of HFLocalInvocationLayer used to invoke local Hugging Face models.
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum length of the output text.
:param use_auth_token: The token to use as HTTP bearer authorization for remote files.
:param use_gpu: Whether to use GPU for inference.
:param devices: The devices to use for inference.
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
For more details about these kwargs, see
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
"""
super().__init__(model_name_or_path)
self.use_auth_token = use_auth_token
self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
if len(self.devices) > 1:
logger.warning(
"Multiple devices are not supported in %s inference, using the first device %s.",
self.__class__.__name__,
self.devices[0],
)
# Due to reflective construction of all invocation layers we might receive some
# unknown kwargs, so we need to take only the relevant.
# For more details refer to Hugging Face pipeline documentation
# Do not use `device_map` AND `device` at the same time as they will conflict
model_input_kwargs = {
key: kwargs[key]
for key in [
"model_kwargs",
"trust_remote_code",
"revision",
"feature_extractor",
"tokenizer",
"config",
"use_fast",
"torch_dtype",
"device_map",
]
if key in kwargs
}
# flatten model_kwargs one level
if "model_kwargs" in model_input_kwargs:
mkwargs = model_input_kwargs.pop("model_kwargs")
model_input_kwargs.update(mkwargs)
torch_dtype = model_input_kwargs.get("torch_dtype")
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if "torch." in torch_dtype:
torch_dtype_resolved = getattr(torch, torch_dtype[len("torch."):])
elif torch_dtype == "auto":
torch_dtype_resolved = torch_dtype
else:
raise ValueError(
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
)
elif isinstance(torch_dtype, torch.dtype):
torch_dtype_resolved = torch_dtype
else:
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
model_input_kwargs["torch_dtype"] = torch_dtype_resolved
if len(model_input_kwargs) > 0:
logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
self.pipe = pipeline(
"text2text-generation",
model=model_name_or_path,
device=self.devices[0] if "device_map" not in model_input_kwargs else None,
use_auth_token=self.use_auth_token,
model_kwargs=model_input_kwargs,
)
# This is how the default max_length is determined for Text2TextGenerationPipeline shown here
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/pipelines/text2text_generation.html
# max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or self.pipe.model.config.max_length
def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated text using the local Hugging Face transformers model
:return: A list of generated text.
Note: Only kwargs relevant to Text2TextGenerationPipeline are passed to Hugging Face as model_input_kwargs.
Other kwargs are ignored.
"""
output: List[Dict[str, str]] = []
stop_words = kwargs.pop("stop_words", None)
top_k = kwargs.pop("top_k", None)
if kwargs and "prompt" in kwargs:
prompt = kwargs.pop("prompt")
# Consider only Text2TextGenerationPipeline relevant, ignore others
# For more details refer to Hugging Face Text2TextGenerationPipeline documentation
# TODO resolve these kwargs from the pipeline signature
model_input_kwargs = {
key: kwargs[key]
for key in ["return_tensors", "return_text", "clean_up_tokenization_spaces", "truncation"]
if key in kwargs
}
if stop_words:
sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words)
model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
if top_k:
model_input_kwargs["num_return_sequences"] = top_k
model_input_kwargs["num_beams"] = top_k
output = self.pipe(prompt, max_length=self.max_length, **model_input_kwargs)
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
if stop_words:
# Although HF generates text until a stop word is encountered, it unfortunately includes the stop word in the output.
# We want to exclude it to be consistent with other invocation layers.
for idx, _ in enumerate(generated_texts):
for stop_word in stop_words:
generated_texts[idx] = generated_texts[idx].replace(stop_word, "").strip()
return generated_texts
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that the length of the prompt and answer is within the max tokens limit of the model.
If needed, truncate the prompt text so that it fits within the limit.
:param prompt: Prompt text to be sent to the generative model.
"""
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= self.pipe.tokenizer.model_max_length:
return prompt
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens such that the prompt length and "
"answer length (%s tokens) fits within the max token limit (%s tokens). "
"Shorten the prompt to prevent it from being cut off",
n_prompt_tokens,
self.pipe.tokenizer.model_max_length - n_answer_tokens,
n_answer_tokens,
self.pipe.tokenizer.model_max_length,
)
tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
tokenized_payload[: self.pipe.tokenizer.model_max_length - n_answer_tokens]
)
return decoded_string
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
try:
config = AutoConfig.from_pretrained(model_name_or_path)
except OSError:
# This is needed so OpenAI models are skipped over
return False
if not all(m in model_name_or_path for m in ["flan", "t5"]):
logger.warning(
"PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
"Many of the default prompts and PromptTemplates will likely not work as intended. "
"Use custom prompts and PromptTemplates specific to the %s model",
model_name_or_path,
)
supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
return config.architectures[0] in supported_models
class OpenAIInvocationLayer(PromptModelInvocationLayer):
"""
PromptModelInvocationLayer implementation for OpenAI's GPT-3 InstructGPT models. Invocations are made via the REST API.
See [OpenAI GPT-3](https://platform.openai.com/docs/models/gpt-3) for more details.
Note: kwargs other than init parameter names are ignored to enable reflective construction of the class,
as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
"""
def __init__(
self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs
):
"""
Creates an instance of OpenAIInvocationLayer for OpenAI's GPT-3 InstructGPT models.
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum length of the output text.
:param api_key: The OpenAI API key.
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
kwargs includes: suffix, temperature, top_p, presence_penalty, frequency_penalty, best_of, n, max_tokens,
logit_bias, stop, echo, and logprobs. For more details about these kwargs, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
super().__init__(model_name_or_path)
if not isinstance(api_key, str) or len(api_key) == 0:
raise OpenAIError(
f"api_key {api_key} must be a valid OpenAI key. Visit https://openai.com/api/ to get one."
)
self.api_key = api_key
# 16 is the default length for answers from OpenAI shown in the docs
# here, https://platform.openai.com/docs/api-reference/completions/create.
# max_length must be set otherwise OpenAIInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or 16
# Due to reflective construction of all invocation layers we might receive some
# unknown kwargs, so we need to take only the relevant.
# For more details refer to OpenAI documentation
self.model_input_kwargs = {
key: kwargs[key]
for key in [
"suffix",
"max_tokens",
"temperature",
"top_p",
"n",
"logprobs",
"echo",
"stop",
"presence_penalty",
"frequency_penalty",
"best_of",
"logit_bias",
]
if key in kwargs
}
tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(
model_name=self.model_name_or_path
)
self.max_tokens_limit = max_tokens_limit
self._tokenizer = load_openai_tokenizer(tokenizer_name=tokenizer_name)
@property
def url(self) -> str:
return "https://api.openai.com/v1/completions"
@property
def headers(self) -> Dict[str, str]:
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
:return: A list of generated responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt = kwargs.get("prompt")
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
kwargs_with_defaults = self.model_input_kwargs
if kwargs:
# we use keyword stop_words but OpenAI uses stop
if "stop_words" in kwargs:
kwargs["stop"] = kwargs.pop("stop_words")
if "top_k" in kwargs:
top_k = kwargs.pop("top_k")
kwargs["n"] = top_k
kwargs["best_of"] = top_k
kwargs_with_defaults.update(kwargs)
payload = {
"model": self.model_name_or_path,
"prompt": prompt,
"suffix": kwargs_with_defaults.get("suffix", None),
"max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
"temperature": kwargs_with_defaults.get("temperature", 0.7),
"top_p": kwargs_with_defaults.get("top_p", 1),
"n": kwargs_with_defaults.get("n", 1),
"stream": False, # no support for streaming
"logprobs": kwargs_with_defaults.get("logprobs", None),
"echo": kwargs_with_defaults.get("echo", False),
"stop": kwargs_with_defaults.get("stop", None),
"presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
"best_of": kwargs_with_defaults.get("best_of", 1),
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
}
res = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_text_completion_answers(result=res, payload=payload)
responses = [ans["text"].strip() for ans in res["choices"]]
return responses
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that the length of the prompt and answer is within the max tokens limit of the model.
If needed, truncate the prompt text so that it fits within the limit.
:param prompt: Prompt text to be sent to the generative model.
"""
n_prompt_tokens = count_openai_tokens(prompt, self._tokenizer)
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= self.max_tokens_limit:
return prompt
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens such that the prompt length and "
"answer length (%s tokens) fits within the max token limit (%s tokens). "
"Reduce the length of the prompt to prevent it from being cut off.",
n_prompt_tokens,
self.max_tokens_limit - n_answer_tokens,
n_answer_tokens,
self.max_tokens_limit,
)
if USE_TIKTOKEN:
tokenized_payload = self._tokenizer.encode(prompt)
decoded_string = self._tokenizer.decode(tokenized_payload[: self.max_tokens_limit - n_answer_tokens])
else:
tokenized_payload = self._tokenizer.tokenize(prompt)
decoded_string = self._tokenizer.convert_tokens_to_string(
tokenized_payload[: self.max_tokens_limit - n_answer_tokens]
)
return decoded_string
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
return valid_model and kwargs.get("azure_base_url") is None
class AzureOpenAIInvocationLayer(OpenAIInvocationLayer):
"""
Azure OpenAI Invocation Layer
This layer is used to invoke the OpenAI API on Azure. It is essentially the same as the OpenAIInvocationLayer
with two additional parameters: azure_base_url and azure_deployment_name. The azure_base_url is the URL of the Azure OpenAI
endpoint and the azure_deployment_name is the name of the deployment.
"""
def __init__(
self,
azure_base_url: str,
azure_deployment_name: str,
api_key: str,
api_version: str = "2022-12-01",
model_name_or_path: str = "text-davinci-003",
max_length: Optional[int] = 100,
**kwargs,
):
super().__init__(api_key, model_name_or_path, max_length, **kwargs)
self.azure_base_url = azure_base_url
self.azure_deployment_name = azure_deployment_name
self.api_version = api_version
@property
def url(self) -> str:
return f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}"
@property
def headers(self) -> Dict[str, str]:
return {"api-key": self.api_key, "Content-Type": "application/json"}
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
"""
Ensures the Azure OpenAI invocation layer is selected when azure_base_url and azure_deployment_name are provided,
in addition to the model being one of the supported models.
"""
valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
return (
valid_model and kwargs.get("azure_base_url") is not None and kwargs.get("azure_deployment_name") is not None
)
class PromptModel(BaseComponent):
    """
    The PromptModel class is a component that uses a pre-trained model to perform tasks based on a prompt. Out of
@@ -653,6 +186,7 @@ class PromptModel(BaseComponent):
        use_auth_token: Optional[Union[str, bool]] = None,
        use_gpu: Optional[bool] = None,
        devices: Optional[List[Union[str, torch.device]]] = None,
+        invocation_layer_class: Optional[str] = None,
        model_kwargs: Optional[Dict] = None,
    ):
        """
@@ -664,6 +198,8 @@ class PromptModel(BaseComponent):
        :param use_auth_token: The Hugging Face token to use.
        :param use_gpu: Whether to use GPU or not.
        :param devices: The devices to use where the model is loaded.
+        :param invocation_layer_class: The custom invocation layer class to use. Use a dotted notation indicating the
+            path from a module's global scope to the class. If None, known invocation layers are used.
        :param model_kwargs: Additional keyword arguments passed to the underlying model.

        Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
@@ -681,11 +217,28 @@
        self.model_kwargs = model_kwargs if model_kwargs else {}

-        self.invocation_layers: List[Type[PromptModelInvocationLayer]] = []
-        self.register(HFLocalInvocationLayer)  # pylint: disable=W0108
-        self.register(OpenAIInvocationLayer)  # pylint: disable=W0108
-        self.register(AzureOpenAIInvocationLayer)  # pylint: disable=W0108
+        self.invocation_layer_classes: List[Type[PromptModelInvocationLayer]] = known_providers()
+        if invocation_layer_class:
+            klass: Optional[Type[PromptModelInvocationLayer]] = None
+            if isinstance(invocation_layer_class, str):
+                # try to find the invocation_layer_class provider class
+                search_path: List[str] = [
+                    f"haystack.nodes.prompt.providers.{invocation_layer_class}",
+                    invocation_layer_class,
+                ]
+                klass = next((pydoc.locate(path) for path in search_path if pydoc.locate(path)), None)  # type: ignore
+            if not klass:
+                raise ValueError(
+                    f"Could not locate PromptModelInvocationLayer class with name {invocation_layer_class}. "
+                    f"Make sure to pass the full path to the class."
+                )
+            if not issubclass(klass, PromptModelInvocationLayer):
+                raise ValueError(f"Class {invocation_layer_class} is not a subclass of PromptModelInvocationLayer.")
+            logger.info("Registering custom invocation layer class %s", klass)
+            self.register(klass)

        self.model_invocation_layer = self.create_invocation_layer()
@@ -698,15 +251,17 @@
        }
        all_kwargs = {**self.model_kwargs, **kwargs}

-        for invocation_layer in self.invocation_layers:
+        # search all invocation layer classes and find the first one that supports the model,
+        # then create an instance of that invocation layer
+        for invocation_layer in self.invocation_layer_classes:
            if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
                return invocation_layer(
                    model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
                )
        raise ValueError(
-            f"Model {self.model_name_or_path} is not supported - no invocation layer found."
-            f" Currently supported models are: {self.invocation_layers}"
-            f" Register a new invocation layer for {self.model_name_or_path} using the register method."
+            f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
+            f" Currently supported invocation layers are: {self.invocation_layer_classes}"
+            f" You can implement and provide custom invocation layer for {self.model_name_or_path} via PromptModel init."
        )

    def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
@@ -714,7 +269,7 @@
        """
        Registers additional prompt model invocation layer. It takes a function that returns a boolean as a
        matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface.
        """
-        self.invocation_layers.append(invocation_layer)
+        self.invocation_layer_classes.append(invocation_layer)

    def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
        """

File: haystack/nodes/prompt/providers.py

@@ -0,0 +1,481 @@
import logging
from abc import abstractmethod
from typing import Dict, List, Optional, Union, Type
import torch
from transformers import (
pipeline,
AutoConfig,
StoppingCriteriaList,
StoppingCriteria,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
from haystack.errors import OpenAIError
from haystack.modeling.utils import initialize_device_settings
from haystack.utils.openai_utils import (
USE_TIKTOKEN,
openai_request,
_openai_text_completion_tokenization_details,
load_openai_tokenizer,
_check_openai_text_completion_answers,
count_openai_tokens,
)
logger = logging.getLogger(__name__)
class PromptModelInvocationLayer:
"""
PromptModelInvocationLayer implementations execute a prompt on an underlying model.
The implementation can be a simple invocation of the underlying model running in a local runtime, or
it can even be remote, for example, a call to a remote API endpoint.
"""
def __init__(self, model_name_or_path: str, **kwargs):
"""
Creates a new PromptModelInvocationLayer instance.
:param model_name_or_path: The name or path of the underlying model.
:param kwargs: Additional keyword arguments passed to the underlying model.
"""
if model_name_or_path is None or len(model_name_or_path) == 0:
raise ValueError("model_name_or_path cannot be None or empty string")
self.model_name_or_path = model_name_or_path
@abstractmethod
def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated text using the underlying model.
:return: A list of generated text.
"""
pass
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
"""
Checks if the given model is supported by this invocation layer.
:param model_name_or_path: The name or path of the model.
:param kwargs: additional keyword arguments passed to the underlying model which might be used to determine
if the model is supported.
:return: True if this invocation layer supports the model, False otherwise.
"""
return False
@abstractmethod
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that length of the prompt and answer is within the maximum token length of the PromptModel.
:param prompt: Prompt text to be sent to the generative model.
"""
pass
def known_providers() -> List[Type[PromptModelInvocationLayer]]:
return [HFLocalInvocationLayer, OpenAIInvocationLayer, AzureOpenAIInvocationLayer]
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
"""
def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stop_words: List[str]):
super().__init__()
self.stop_words = tokenizer.encode(stop_words, add_special_tokens=False, return_tensors="pt")
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(torch.isin(input_ids[-1], self.stop_words[-1]))
class HFLocalInvocationLayer(PromptModelInvocationLayer):
"""
A subclass of the PromptModelInvocationLayer class. It loads a pre-trained model from Hugging Face and
passes a prepared prompt into that model.
Note: kwargs other than init parameter names are ignored to enable reflective construction of the class,
as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
"""
def __init__(
self,
model_name_or_path: str = "google/flan-t5-base",
max_length: Optional[int] = 100,
use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = True,
devices: Optional[List[Union[str, torch.device]]] = None,
**kwargs,
):
"""
Creates an instance of HFLocalInvocationLayer used to invoke local Hugging Face models.
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum length of the output text.
:param use_auth_token: The token to use as HTTP bearer authorization for remote files.
:param use_gpu: Whether to use GPU for inference.
:param devices: The devices to use for inference.
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
includes: trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
For more details about these kwargs, see
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
"""
super().__init__(model_name_or_path)
self.use_auth_token = use_auth_token
self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
if len(self.devices) > 1:
logger.warning(
"Multiple devices are not supported in %s inference, using the first device %s.",
self.__class__.__name__,
self.devices[0],
)
# Due to reflective construction of all invocation layers we might receive some
# unknown kwargs, so we need to take only the relevant.
# For more details refer to Hugging Face pipeline documentation
# Do not use `device_map` AND `device` at the same time as they will conflict
model_input_kwargs = {
key: kwargs[key]
for key in [
"model_kwargs",
"trust_remote_code",
"revision",
"feature_extractor",
"tokenizer",
"config",
"use_fast",
"torch_dtype",
"device_map",
]
if key in kwargs
}
# flatten model_kwargs one level
if "model_kwargs" in model_input_kwargs:
mkwargs = model_input_kwargs.pop("model_kwargs")
model_input_kwargs.update(mkwargs)
torch_dtype = model_input_kwargs.get("torch_dtype")
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if "torch." in torch_dtype:
torch_dtype_resolved = getattr(torch, torch_dtype[len("torch."):])
elif torch_dtype == "auto":
torch_dtype_resolved = torch_dtype
else:
raise ValueError(
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
)
elif isinstance(torch_dtype, torch.dtype):
torch_dtype_resolved = torch_dtype
else:
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
model_input_kwargs["torch_dtype"] = torch_dtype_resolved
if len(model_input_kwargs) > 0:
logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
self.pipe = pipeline(
"text2text-generation",
model=model_name_or_path,
device=self.devices[0] if "device_map" not in model_input_kwargs else None,
use_auth_token=self.use_auth_token,
model_kwargs=model_input_kwargs,
)
# This is how the default max_length is determined for Text2TextGenerationPipeline shown here
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/pipelines/text2text_generation.html
# max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or self.pipe.model.config.max_length
def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated text using the local Hugging Face transformers model
:return: A list of generated text.
Note: Only kwargs relevant to Text2TextGenerationPipeline are passed to Hugging Face as model_input_kwargs.
Other kwargs are ignored.
"""
output: List[Dict[str, str]] = []
stop_words = kwargs.pop("stop_words", None)
top_k = kwargs.pop("top_k", None)
if kwargs and "prompt" in kwargs:
prompt = kwargs.pop("prompt")
# Consider only Text2TextGenerationPipeline relevant, ignore others
# For more details refer to Hugging Face Text2TextGenerationPipeline documentation
# TODO resolve these kwargs from the pipeline signature
model_input_kwargs = {
key: kwargs[key]
for key in ["return_tensors", "return_text", "clean_up_tokenization_spaces", "truncation"]
if key in kwargs
}
if stop_words:
sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words)
model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
if top_k:
model_input_kwargs["num_return_sequences"] = top_k
model_input_kwargs["num_beams"] = top_k
output = self.pipe(prompt, max_length=self.max_length, **model_input_kwargs)
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
if stop_words:
# Although HF generates text until a stop word is encountered, it unfortunately includes the stop word in the output.
# We want to exclude it to be consistent with other invocation layers.
for idx, _ in enumerate(generated_texts):
for stop_word in stop_words:
generated_texts[idx] = generated_texts[idx].replace(stop_word, "").strip()
return generated_texts
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that the length of the prompt and answer is within the max tokens limit of the model.
If needed, truncate the prompt text so that it fits within the limit.
:param prompt: Prompt text to be sent to the generative model.
"""
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= self.pipe.tokenizer.model_max_length:
return prompt
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens such that the prompt length and "
"answer length (%s tokens) fits within the max token limit (%s tokens). "
"Shorten the prompt to prevent it from being cut off",
n_prompt_tokens,
self.pipe.tokenizer.model_max_length - n_answer_tokens,
n_answer_tokens,
self.pipe.tokenizer.model_max_length,
)
tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
tokenized_payload[: self.pipe.tokenizer.model_max_length - n_answer_tokens]
)
return decoded_string
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
try:
config = AutoConfig.from_pretrained(model_name_or_path)
except OSError:
# This is needed so OpenAI models are skipped over
return False
if not all(m in model_name_or_path for m in ["flan", "t5"]):
logger.warning(
"PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
"Many of the default prompts and PromptTemplates will likely not work as intended. "
"Use custom prompts and PromptTemplates specific to the %s model",
model_name_or_path,
)
supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
return config.architectures[0] in supported_models
class OpenAIInvocationLayer(PromptModelInvocationLayer):
"""
PromptModelInvocationLayer implementation for OpenAI's GPT-3 InstructGPT models. Invocations are made via the REST API.
See [OpenAI GPT-3](https://platform.openai.com/docs/models/gpt-3) for more details.
Note: kwargs other than init parameter names are ignored to enable reflective construction of the class
as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
"""
def __init__(
self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs
):
"""
Creates an instance of OpenAIInvocationLayer for OpenAI's GPT-3 InstructGPT models.
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum length of the output text.
:param api_key: The OpenAI API key.
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of OpenAIInvocationLayer might receive some unrelated
kwargs. Only the kwargs relevant to OpenAIInvocationLayer are considered. The list of OpenAI-relevant
kwargs includes: suffix, temperature, top_p, presence_penalty, frequency_penalty, best_of, n, max_tokens,
logit_bias, stop, echo, and logprobs. For more details about these kwargs, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
super().__init__(model_name_or_path)
if not isinstance(api_key, str) or len(api_key) == 0:
raise OpenAIError(
f"api_key {api_key} must be a valid OpenAI key. Visit https://openai.com/api/ to get one."
)
self.api_key = api_key
# 16 is the default length for answers from OpenAI shown in the docs
# here, https://platform.openai.com/docs/api-reference/completions/create.
# max_length must be set otherwise OpenAIInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or 16
# Due to reflective construction of all invocation layers we might receive some
# unknown kwargs, so we need to take only the relevant.
# For more details refer to OpenAI documentation
self.model_input_kwargs = {
key: kwargs[key]
for key in [
"suffix",
"max_tokens",
"temperature",
"top_p",
"n",
"logprobs",
"echo",
"stop",
"presence_penalty",
"frequency_penalty",
"best_of",
"logit_bias",
]
if key in kwargs
}
tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(
model_name=self.model_name_or_path
)
self.max_tokens_limit = max_tokens_limit
self._tokenizer = load_openai_tokenizer(tokenizer_name=tokenizer_name)
@property
def url(self) -> str:
return "https://api.openai.com/v1/completions"
@property
def headers(self) -> Dict[str, str]:
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
:return: A list of generated responses.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt = kwargs.get("prompt")
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
kwargs_with_defaults = self.model_input_kwargs
if kwargs:
# we use keyword stop_words but OpenAI uses stop
if "stop_words" in kwargs:
kwargs["stop"] = kwargs.pop("stop_words")
if "top_k" in kwargs:
top_k = kwargs.pop("top_k")
kwargs["n"] = top_k
kwargs["best_of"] = top_k
kwargs_with_defaults.update(kwargs)
payload = {
"model": self.model_name_or_path,
"prompt": prompt,
"suffix": kwargs_with_defaults.get("suffix", None),
"max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
"temperature": kwargs_with_defaults.get("temperature", 0.7),
"top_p": kwargs_with_defaults.get("top_p", 1),
"n": kwargs_with_defaults.get("n", 1),
"stream": False, # no support for streaming
"logprobs": kwargs_with_defaults.get("logprobs", None),
"echo": kwargs_with_defaults.get("echo", False),
"stop": kwargs_with_defaults.get("stop", None),
"presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
"best_of": kwargs_with_defaults.get("best_of", 1),
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
}
res = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_text_completion_answers(result=res, payload=payload)
responses = [ans["text"].strip() for ans in res["choices"]]
return responses
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that the length of the prompt and answer is within the max tokens limit of the model.
If needed, truncate the prompt text so that it fits within the limit.
:param prompt: Prompt text to be sent to the generative model.
"""
n_prompt_tokens = count_openai_tokens(prompt, self._tokenizer)
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= self.max_tokens_limit:
return prompt
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens such that the prompt length and "
"answer length (%s tokens) fits within the max token limit (%s tokens). "
"Reduce the length of the prompt to prevent it from being cut off.",
n_prompt_tokens,
self.max_tokens_limit - n_answer_tokens,
n_answer_tokens,
self.max_tokens_limit,
)
if USE_TIKTOKEN:
tokenized_payload = self._tokenizer.encode(prompt)
decoded_string = self._tokenizer.decode(tokenized_payload[: self.max_tokens_limit - n_answer_tokens])
else:
tokenized_payload = self._tokenizer.tokenize(prompt)
decoded_string = self._tokenizer.convert_tokens_to_string(
tokenized_payload[: self.max_tokens_limit - n_answer_tokens]
)
return decoded_string
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
return valid_model and kwargs.get("azure_base_url") is None
class AzureOpenAIInvocationLayer(OpenAIInvocationLayer):
"""
Azure OpenAI Invocation Layer
This layer is used to invoke the OpenAI API on Azure. It is essentially the same as the OpenAIInvocationLayer
with two additional parameters: azure_base_url and azure_deployment_name. The azure_base_url is the URL of the Azure OpenAI
endpoint and the azure_deployment_name is the name of the deployment.
"""
def __init__(
self,
azure_base_url: str,
azure_deployment_name: str,
api_key: str,
api_version: str = "2022-12-01",
model_name_or_path: str = "text-davinci-003",
max_length: Optional[int] = 100,
**kwargs,
):
super().__init__(api_key, model_name_or_path, max_length, **kwargs)
self.azure_base_url = azure_base_url
self.azure_deployment_name = azure_deployment_name
self.api_version = api_version
@property
def url(self) -> str:
return f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}"
@property
def headers(self) -> Dict[str, str]:
return {"api-key": self.api_key, "Content-Type": "application/json"}
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
"""
Ensures the Azure OpenAI invocation layer is selected when azure_base_url and azure_deployment_name are provided,
in addition to the model being one of the supported models.
"""
valid_model = any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path)
return (
valid_model and kwargs.get("azure_base_url") is not None and kwargs.get("azure_deployment_name") is not None
)
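known_providers() fixes the default resolution order (Hugging Face, then OpenAI, then Azure OpenAI), and PromptModel.create_invocation_layer instantiates the first class whose supports() returns True. A minimal sketch of that selection loop; the helper name resolve_layer is illustrative, not part of the Haystack API:

```python
# Hedged sketch of the selection logic this file enables.
from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers

def resolve_layer(model_name_or_path: str, **kwargs) -> PromptModelInvocationLayer:
    for layer_cls in known_providers():
        # supports() may inspect kwargs, e.g. azure_base_url for Azure OpenAI
        if layer_cls.supports(model_name_or_path, **kwargs):
            return layer_cls(model_name_or_path=model_name_or_path, **kwargs)
    raise ValueError(f"No invocation layer found for {model_name_or_path}")
```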

File: test/nodes/test_prompt_node.py

@@ -8,7 +8,8 @@ import torch
from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
-from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer
+from haystack.nodes.prompt import PromptModelInvocationLayer
+from haystack.nodes.prompt.providers import HFLocalInvocationLayer

def skip_test_for_invalid_key(prompt_model):
@@ -16,6 +17,21 @@ def skip_test_for_invalid_key(prompt_model):
    pytest.skip("No API key found, skipping test")
+class CustomInvocationLayer(PromptModelInvocationLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def invoke(self, *args, **kwargs):
+        return ["fake_response"]
+
+    def _ensure_token_limit(self, prompt: str) -> str:
+        return prompt
+
+    @classmethod
+    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
+        return model_name_or_path == "fake_model"
@pytest.fixture
def get_api_key(request):
    if request.param == "openai":
@@ -71,6 +87,15 @@ def test_prompt_template_repr():
    assert str(p) == desired_repr
+@pytest.mark.unit
+def test_prompt_node_with_custom_invocation_layer_from_string():
+    model = PromptModel("fake_model", invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer")
+    pn = PromptNode(model_name_or_path=model)
+    output = pn("Some fake invocation")
+    assert output == ["fake_response"]
@pytest.mark.integration
def test_create_prompt_model():
    model = PromptModel("google/flan-t5-small")
@@ -897,7 +922,7 @@ class TestRunBatch:
        assert isinstance(result["results"][0][0], str)

-@pytest.mark.integration
+@pytest.mark.unit
def test_HFLocalInvocationLayer_supports():
    assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
    assert HFLocalInvocationLayer.supports("bigscience/T0_3B")