feat: Add Hugging Face inferencing PromptNode layer (#4641)

Vladimir Blagojevic 2023-04-14 17:59:17 +02:00 committed by GitHub
parent 6a5acaa1e2
commit 1dcac11133
11 changed files with 257 additions and 11 deletions
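For context, a minimal usage sketch of the feature this commit adds (illustrative, not part of the diff; the token is a placeholder). Assuming PromptNode forwards its api_key to invocation-layer selection, as the supports() changes below suggest, a hosted model is served through the new Hugging Face inference layer:

# Illustrative sketch only (not part of this commit); the token is a placeholder.
from haystack.nodes import PromptNode

# With an api_key, the new HFInferenceEndpointInvocationLayer should be selected for a hosted
# text2text-generation / text-generation model instead of loading the model locally.
prompt_node = PromptNode(
    model_name_or_path="google/flan-t5-xxl",  # or an Inference Endpoint URL
    api_key="hf_...",                         # your Hugging Face token (placeholder)
    max_length=128,
)
print(prompt_node("What is the capital of Germany?"))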

View File

@@ -24,4 +24,3 @@ renderer:
add_method_class_prefix: true
add_member_class_prefix: false
filename: pipelines_api.md

View File

@@ -26,4 +26,3 @@ renderer:
add_method_class_prefix: true
add_member_class_prefix: false
filename: prompt_node_api.md

View File

@@ -24,4 +24,3 @@ renderer:
add_method_class_prefix: true
add_member_class_prefix: false
filename: whisper_api.md

View File

@@ -11,7 +11,7 @@ from haystack import Pipeline, BaseComponent, Answer, Document
from haystack.telemetry import send_event
from haystack.agents.agent_step import AgentStep
from haystack.agents.types import Color
from haystack.agents.utils import print_text
from haystack.agents.utils import print_text, STREAMING_CAPABLE_MODELS
from haystack.errors import AgentError
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.nodes.prompt.invocation_layer import TokenStreamingHandler
@@ -180,7 +180,12 @@ class Agent:
self.max_steps = max_steps
self.tool_pattern = tool_pattern
self.final_answer_pattern = final_answer_pattern
self.add_default_logging_callbacks()
# Resolve model name to check if it's a streaming model
if isinstance(self.prompt_node.model_name_or_path, str):
model_name = self.prompt_node.model_name_or_path
else:
model_name = self.prompt_node.model_name_or_path.model_name_or_path
self.add_default_logging_callbacks(streaming=any(m for m in STREAMING_CAPABLE_MODELS if m in model_name))
self.hash = None
self.last_hash = None
self.update_hash()
@@ -197,7 +202,7 @@ class Agent:
logger.debug("Telemetry exception: %s", str(exc))
self.hash = "[an exception occurred during hashing]"
def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN, streaming: bool = False) -> None:
def on_tool_finish(
tool_output: str,
color: Optional[Color] = None,
@@ -215,7 +220,13 @@ class Agent:
self.callback_manager.on_tool_finish += on_tool_finish
self.callback_manager.on_agent_start += on_agent_start
self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
if streaming:
self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
else:
self.callback_manager.on_agent_step += lambda agent_step: print_text(
agent_step.prompt_node_response, color=agent_color
)
def add_tool(self, tool: Tool):
"""

View File

@@ -2,6 +2,8 @@ from typing import Optional
from haystack.agents.types import Color
STREAMING_CAPABLE_MODELS = ["davinci"]
def print_text(text: str, end="", color: Optional[Color] = None) -> None:
"""

View File

@@ -179,3 +179,27 @@ class ImageToTextError(NodeError):
def __init__(self, message: Optional[str] = None):
super().__init__(message=message)
class HuggingFaceInferenceError(NodeError):
"""Exception for issues that occur in the HuggingFace inference node"""
def __init__(
self, message: Optional[str] = None, status_code: Optional[int] = None, send_message_in_event: bool = False
):
super().__init__(message=message, send_message_in_event=send_message_in_event)
self.status_code = status_code
class HuggingFaceInferenceLimitError(HuggingFaceInferenceError):
"""Exception for issues that occur in the HuggingFace inference node due to rate limiting"""
def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)
class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):
"""Exception for issues that occur in the HuggingFace inference node due to unauthorized access"""
def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)
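
A hedged sketch of how calling code might react to the new exception types (the helper function is illustrative, not part of the commit):

# Illustrative handling of the new exception hierarchy (not part of this commit).
from haystack.errors import HuggingFaceInferenceError, HuggingFaceInferenceLimitError

def describe_failure(err: HuggingFaceInferenceError) -> str:
    # HuggingFaceInferenceLimitError fixes status_code to 429, the unauthorized variant to 401.
    if isinstance(err, HuggingFaceInferenceLimitError):
        return "rate limited by the Inference API, back off and retry"
    return f"inference failed with HTTP status {err.status_code}"

print(describe_failure(HuggingFaceInferenceLimitError("API rate limit exceeded")))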

View File

@@ -3,4 +3,5 @@ from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer

View File

@@ -30,7 +30,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
def __init__(
self,
model_name_or_path: str = "google/flan-t5-base",
max_length: Optional[int] = 100,
max_length: int = 100,
use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = True,
devices: Optional[List[Union[str, torch.device]]] = None,
@@ -247,8 +247,9 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
except RuntimeError:
# This will fail for all non-HF models
return False
return task_name in ["text2text-generation", "text-generation"]
# if we are using an api_key it could be HF inference point
using_api_key = kwargs.get("api_key", None) is not None
return not using_api_key and task_name in ["text2text-generation", "text-generation"]
class StopWordsCriteria(StoppingCriteria):
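
The net effect of the two supports() changes is a hand-off from the local layer to the remote one whenever an api_key is supplied; a minimal sketch, assuming supports() receives the same kwargs as the PromptModel (the token is a placeholder):

# Illustrative only: with an api_key the local layer now defers to the remote inference layer.
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInferenceEndpointInvocationLayer

model = "google/flan-t5-xxl"  # a hosted text2text-generation model
print(HFLocalInvocationLayer.supports(model))                                 # True: load and run locally
print(HFLocalInvocationLayer.supports(model, api_key="hf_..."))               # False: defer to the remote layer
print(HFInferenceEndpointInvocationLayer.supports(model, api_key="hf_..."))   # True: use the HF Inference API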

View File

@@ -0,0 +1,209 @@
import json
import os
from typing import Optional, Dict, Union, List, Any
import logging
import requests
from transformers.pipelines import get_task
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
from haystack.errors import (
HuggingFaceInferenceLimitError,
HuggingFaceInferenceUnauthorizedError,
HuggingFaceInferenceError,
)
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.utils.requests import request_with_retry
logger = logging.getLogger(__name__)
HF_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
HF_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
"""
A PromptModelInvocationLayer that invokes a remote Hugging Face Inference Endpoint or the hosted Inference API to prompt the model.
For more details, see the Hugging Face Inference API [documentation](https://huggingface.co/docs/api-inference/index)
and the Hugging Face Inference Endpoints [documentation](https://huggingface.co/inference-endpoints).
The Inference API is free to use but rate limited. If you need an inference solution for production, use the
Inference Endpoints service.
See the documentation for more details: https://huggingface.co/docs/inference-endpoints
"""
def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
"""
Creates an instance of HFInferenceEndpointInvocationLayer.
:param model_name_or_path: can be either:
a) a Hugging Face Inference model name (e.g. google/flan-t5-xxl)
b) a Hugging Face Inference Endpoint URL (e.g. https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud)
:param max_length: The maximum length of the output text.
:param api_key: The Hugging Face API token. You'll need to provide your user token, which can
be found in your Hugging Face account [settings](https://huggingface.co/settings/tokens).
"""
super().__init__(model_name_or_path)
valid_api_key = isinstance(api_key, str) and api_key
if not valid_api_key:
raise ValueError(
f"api_key {api_key} must be a valid Hugging Face token. "
f"Your token is available in your Hugging Face settings page."
)
valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
if not valid_model_name_or_path:
raise ValueError(
f"model_name_or_path {model_name_or_path} must be a valid Hugging Face inference endpoint URL."
)
self.api_key = api_key
self.max_length = max_length
# See https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
# for a list of supported parameters
self.model_input_kwargs = {
key: kwargs[key]
for key in [
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_new_tokens",
"max_time",
"return_full_text",
"num_return_sequences",
"do_sample",
]
if key in kwargs
}
@property
def url(self) -> str:
if HFInferenceEndpointInvocationLayer.is_inference_endpoint(self.model_name_or_path):
# Inference Endpoint URL
# i.e. https://o3x2xh3o4m47mxny.us-east-1.aws.endpoints.huggingface.cloud
url = self.model_name_or_path
else:
url = f"https://api-inference.huggingface.co/models/{self.model_name_or_path}"
return url
@property
def headers(self) -> Dict[str, str]:
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
:return: The list of generated responses.
"""
prompt = kwargs.get("prompt")
if not prompt:
raise ValueError(
f"No prompt provided. Model {self.model_name_or_path} requires prompt."
f"Make sure to provide prompt in kwargs."
)
stop_words = kwargs.pop("stop_words", None)
kwargs_with_defaults = self.model_input_kwargs
if "max_new_tokens" not in kwargs_with_defaults:
kwargs_with_defaults["max_new_tokens"] = self.max_length
if "top_k" in kwargs:
top_k = kwargs.pop("top_k")
kwargs["num_return_sequences"] = top_k
kwargs_with_defaults.update(kwargs)
# see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
accepted_params = [
"top_p",
"top_k",
"temperature",
"repetition_penalty",
"max_new_tokens",
"max_time",
"return_full_text",
"num_return_sequences",
"do_sample",
]
params = {key: kwargs_with_defaults.get(key) for key in accepted_params if key in kwargs_with_defaults}
generated_texts = self._post(data={"inputs": prompt, "parameters": params}, **kwargs)
if stop_words:
for idx, _ in enumerate(generated_texts):
earliest_stop_word_idx = len(generated_texts[idx])
for stop_word in stop_words:
stop_word_idx = generated_texts[idx].find(stop_word)
if stop_word_idx != -1:
earliest_stop_word_idx = min(earliest_stop_word_idx, stop_word_idx)
generated_texts[idx] = generated_texts[idx][:earliest_stop_word_idx]
return generated_texts
def _post(
self,
data: Dict[str, Any],
attempts: int = HF_RETRIES,
status_codes: Optional[List[int]] = None,
timeout: float = HF_TIMEOUT,
**kwargs,
) -> List[str]:
"""
Posts data to the HF inference model and returns a list of generated responses using a REST invocation.
:param data: The data to be sent to the model.
:param attempts: The number of attempts to make.
:param status_codes: The status codes to retry on.
:param timeout: The timeout for the request.
:return: The list of generated responses.
"""
generated_texts: List[str] = []
if status_codes is None:
status_codes = [429]
try:
response = request_with_retry(
method="POST",
status_codes=status_codes,
attempts=attempts,
url=self.url,
headers=self.headers,
json=data,
timeout=timeout,
)
output = json.loads(response.text)
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
except requests.HTTPError as err:
res = err.response
if res.status_code == 429:
raise HuggingFaceInferenceLimitError(f"API rate limit exceeded: {res.text}")
if res.status_code == 401:
raise HuggingFaceInferenceUnauthorizedError(f"API key is invalid: {res.text}")
raise HuggingFaceInferenceError(
f"HuggingFace Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
status_code=res.status_code,
)
return generated_texts
def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
# TODO: new implementation incoming for all layers, let's omit this for now
return prompt
@staticmethod
def is_inference_endpoint(model_name_or_path: str) -> bool:
return model_name_or_path is not None and all(
token in model_name_or_path for token in ["https://", "endpoints"]
)
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
if cls.is_inference_endpoint(model_name_or_path):
return True
else:
# Check if the model is an HF inference API
task_name: Optional[str] = None
is_inference_api = False
try:
task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
is_inference_api = "api_key" in kwargs
except RuntimeError:
# This will fail for all non-HF models
return False
return is_inference_api and task_name in ["text2text-generation", "text-generation"]
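
Finally, a minimal sketch of driving the new layer directly (illustrative; the token is a placeholder, and top_k is remapped to num_return_sequences inside invoke()):

# Illustrative usage of the new layer (not part of this commit; the token is a placeholder).
from haystack.nodes.prompt.invocation_layer import HFInferenceEndpointInvocationLayer

layer = HFInferenceEndpointInvocationLayer(
    api_key="hf_...",                         # Hugging Face token (placeholder)
    model_name_or_path="google/flan-t5-xxl",  # or an Inference Endpoint URL
    max_length=100,                           # becomes max_new_tokens unless overridden
)
responses = layer.invoke(
    prompt="What is the capital of Germany?",
    do_sample=True,
    top_k=2,             # mapped to num_return_sequences before the request is sent
    stop_words=["\n"],   # each response is truncated at the earliest stop word
)
print(responses)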

View File

@@ -340,7 +340,7 @@ def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
@pytest.mark.unit
def test_update_hash():
agent = Agent(prompt_node=mock.Mock(), prompt_template=mock.Mock())
agent = Agent(prompt_node=MockPromptNode(), prompt_template=mock.Mock())
assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e"
agent.add_tool(
Tool(

View File

@@ -363,6 +363,7 @@ class MockReader(BaseReader):
class MockPromptNode(PromptNode):
def __init__(self):
self.default_prompt_template = None
self.model_name_or_path = ""
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
return [""]