mirror of https://github.com/deepset-ai/haystack.git
synced 2026-01-05 11:38:20 +00:00
feat: Add Hugging Face inferencing PromptNode layer (#4641)
This commit is contained in:
parent 6a5acaa1e2
commit 1dcac11133
@@ -24,4 +24,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: pipelines_api.md

@@ -26,4 +26,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: prompt_node_api.md

@@ -24,4 +24,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: whisper_api.md
@@ -11,7 +11,7 @@ from haystack import Pipeline, BaseComponent, Answer, Document
 from haystack.telemetry import send_event
 from haystack.agents.agent_step import AgentStep
 from haystack.agents.types import Color
-from haystack.agents.utils import print_text
+from haystack.agents.utils import print_text, STREAMING_CAPABLE_MODELS
 from haystack.errors import AgentError
 from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
 from haystack.nodes.prompt.invocation_layer import TokenStreamingHandler
@@ -180,7 +180,12 @@ class Agent:
         self.max_steps = max_steps
         self.tool_pattern = tool_pattern
         self.final_answer_pattern = final_answer_pattern
-        self.add_default_logging_callbacks()
+        # Resolve model name to check if it's a streaming model
+        if isinstance(self.prompt_node.model_name_or_path, str):
+            model_name = self.prompt_node.model_name_or_path
+        else:
+            model_name = self.prompt_node.model_name_or_path.model_name_or_path
+        self.add_default_logging_callbacks(streaming=any(m for m in STREAMING_CAPABLE_MODELS if m in model_name))
         self.hash = None
         self.last_hash = None
         self.update_hash()
@@ -197,7 +202,7 @@ class Agent:
             logger.debug("Telemetry exception: %s", str(exc))
             self.hash = "[an exception occurred during hashing]"
 
-    def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
+    def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN, streaming: bool = False) -> None:
         def on_tool_finish(
             tool_output: str,
             color: Optional[Color] = None,
@@ -215,7 +220,13 @@ class Agent:
 
         self.callback_manager.on_tool_finish += on_tool_finish
         self.callback_manager.on_agent_start += on_agent_start
-        self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
+
+        if streaming:
+            self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
+        else:
+            self.callback_manager.on_agent_step += lambda agent_step: print_text(
+                agent_step.prompt_node_response, color=agent_color
+            )
 
     def add_tool(self, tool: Tool):
         """
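The two hunks above wire the Agent's default logging callbacks to token streaming only when the underlying model is known to stream. A minimal standalone sketch of that decision, not part of the diff (the helper name resolve_streaming is chosen here purely for illustration):

from typing import List, Union

STREAMING_CAPABLE_MODELS: List[str] = ["davinci"]

def resolve_streaming(model_name_or_path: Union[str, object]) -> bool:
    # PromptNode.model_name_or_path is either a plain string or a PromptModel,
    # which itself carries a .model_name_or_path string.
    if isinstance(model_name_or_path, str):
        model_name = model_name_or_path
    else:
        model_name = model_name_or_path.model_name_or_path
    return any(m in model_name for m in STREAMING_CAPABLE_MODELS)

assert resolve_streaming("text-davinci-003") is True
assert resolve_streaming("google/flan-t5-xxl") is False

When this check is true, the Agent prints tokens as they arrive via on_new_token; otherwise it prints the full prompt_node_response once per agent step.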
@@ -2,6 +2,8 @@ from typing import Optional
 
 from haystack.agents.types import Color
 
+STREAMING_CAPABLE_MODELS = ["davinci"]
+
 
 def print_text(text: str, end="", color: Optional[Color] = None) -> None:
     """
@@ -179,3 +179,27 @@ class ImageToTextError(NodeError):
 
     def __init__(self, message: Optional[str] = None):
         super().__init__(message=message)
+
+
+class HuggingFaceInferenceError(NodeError):
+    """Exception for issues that occur in the HuggingFace inference node"""
+
+    def __init__(
+        self, message: Optional[str] = None, status_code: Optional[int] = None, send_message_in_event: bool = False
+    ):
+        super().__init__(message=message, send_message_in_event=send_message_in_event)
+        self.status_code = status_code
+
+
+class HuggingFaceInferenceLimitError(HuggingFaceInferenceError):
+    """Exception for issues that occur in the HuggingFace inference node due to rate limiting"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)
+
+
+class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):
+    """Exception for issues that occur in the HuggingFace inference node due to unauthorized access"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)
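The new exception hierarchy carries the HTTP status code, so callers can distinguish rate limiting from bad credentials. A hedged sketch of how calling code might branch on these classes (the classify helper below is illustrative, not part of Haystack):

from haystack.errors import (
    HuggingFaceInferenceError,
    HuggingFaceInferenceLimitError,
    HuggingFaceInferenceUnauthorizedError,
)

def classify(err: HuggingFaceInferenceError) -> str:
    # The subclasses pin status_code to 429 and 401 respectively.
    if isinstance(err, HuggingFaceInferenceLimitError):
        return "rate limited: back off and retry"
    if isinstance(err, HuggingFaceInferenceUnauthorizedError):
        return "unauthorized: check the Hugging Face API token"
    return f"inference error (status {err.status_code})"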
@@ -3,4 +3,5 @@ from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
 from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
+from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
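Registering HFInferenceEndpointInvocationLayer here matters because PromptModel picks an invocation layer by asking each known layer whether it supports the requested model. A rough, simplified sketch of that selection loop (not the verbatim Haystack implementation):

def pick_invocation_layer(layer_classes, model_name_or_path: str, **kwargs):
    # Return the first registered layer that claims support for this model.
    for layer_cls in layer_classes:
        if layer_cls.supports(model_name_or_path, **kwargs):
            return layer_cls
    raise ValueError(f"No invocation layer found for {model_name_or_path}")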
@@ -30,7 +30,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
     def __init__(
         self,
         model_name_or_path: str = "google/flan-t5-base",
-        max_length: Optional[int] = 100,
+        max_length: int = 100,
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = True,
         devices: Optional[List[Union[str, torch.device]]] = None,
@@ -247,8 +247,9 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         except RuntimeError:
             # This will fail for all non-HF models
             return False
 
-        return task_name in ["text2text-generation", "text-generation"]
+        # if we are using an api_key it could be HF inference point
+        using_api_key = kwargs.get("api_key", None) is not None
+        return not using_api_key and task_name in ["text2text-generation", "text-generation"]
 
 
 class StopWordsCriteria(StoppingCriteria):
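The hunk above makes HFLocalInvocationLayer decline any request that carries an api_key, so the new remote layer can claim it instead. A standalone sketch of the rule (local_layer_supports is a hypothetical helper used only for illustration):

def local_layer_supports(task_name: str, **kwargs) -> bool:
    using_api_key = kwargs.get("api_key", None) is not None
    return not using_api_key and task_name in ["text2text-generation", "text-generation"]

assert local_layer_supports("text2text-generation") is True
assert local_layer_supports("text-generation", api_key="hf_xxx") is False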
haystack/nodes/prompt/invocation_layer/hugging_face_inference.py (new file, 209 lines)
@@ -0,0 +1,209 @@
import json
import os
from typing import Optional, Dict, Union, List, Any
import logging

import requests
from transformers.pipelines import get_task

from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
from haystack.errors import (
    HuggingFaceInferenceLimitError,
    HuggingFaceInferenceUnauthorizedError,
    HuggingFaceInferenceError,
)
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.utils.requests import request_with_retry

logger = logging.getLogger(__name__)
HF_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
HF_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))

class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
    """
    A PromptModelInvocationLayer that invokes a remote Hugging Face Inference Endpoint or the hosted
    Inference API to prompt the model.
    For more details, see the Hugging Face Inference API [documentation](https://huggingface.co/docs/api-inference/index)
    and the Hugging Face Inference Endpoints [documentation](https://huggingface.co/inference-endpoints).

    The Inference API is free to use but rate limited. If you need an inference solution for production,
    use the Inference Endpoints service.

    See the documentation for more details: https://huggingface.co/docs/inference-endpoints
    """

    def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
        """
        Creates an instance of HFInferenceEndpointInvocationLayer.

        :param model_name_or_path: Can be either:
            a) a Hugging Face Inference model name (for example, google/flan-t5-xxl)
            b) a Hugging Face Inference Endpoint URL (for example, https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud)
        :param max_length: The maximum length of the output text.
        :param api_key: The Hugging Face API token. You'll need to provide your user token, which can
            be found in your Hugging Face account [settings](https://huggingface.co/settings/tokens).
        """
        super().__init__(model_name_or_path)
        valid_api_key = isinstance(api_key, str) and api_key
        if not valid_api_key:
            raise ValueError(
                f"api_key {api_key} must be a valid Hugging Face token. "
                f"Your token is available in your Hugging Face settings page."
            )
        valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
        if not valid_model_name_or_path:
            raise ValueError(
                f"model_name_or_path {model_name_or_path} must be a valid Hugging Face inference endpoint URL."
            )
        self.api_key = api_key
        self.max_length = max_length

        # See https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
        # for a list of supported parameters
        self.model_input_kwargs = {
            key: kwargs[key]
            for key in [
                "top_k",
                "top_p",
                "temperature",
                "repetition_penalty",
                "max_new_tokens",
                "max_time",
                "return_full_text",
                "num_return_sequences",
                "do_sample",
            ]
            if key in kwargs
        }
    @property
    def url(self) -> str:
        if HFInferenceEndpointInvocationLayer.is_inference_endpoint(self.model_name_or_path):
            # Inference Endpoint URL
            # e.g. https://o3x2xh3o4m47mxny.us-east-1.aws.endpoints.huggingface.cloud
            url = self.model_name_or_path
        else:
            url = f"https://api-inference.huggingface.co/models/{self.model_name_or_path}"
        return url

    @property
    def headers(self) -> Dict[str, str]:
        return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
    def invoke(self, *args, **kwargs):
        """
        Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
        :return: The list of generated responses.
        """
        prompt = kwargs.get("prompt")
        if not prompt:
            raise ValueError(
                f"No prompt provided. Model {self.model_name_or_path} requires a prompt. "
                f"Make sure to provide the prompt in kwargs."
            )
        stop_words = kwargs.pop("stop_words", None)

        kwargs_with_defaults = self.model_input_kwargs
        if "max_new_tokens" not in kwargs_with_defaults:
            kwargs_with_defaults["max_new_tokens"] = self.max_length

        if "top_k" in kwargs:
            top_k = kwargs.pop("top_k")
            kwargs["num_return_sequences"] = top_k
        kwargs_with_defaults.update(kwargs)
        # see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
        accepted_params = [
            "top_p",
            "top_k",
            "temperature",
            "repetition_penalty",
            "max_new_tokens",
            "max_time",
            "return_full_text",
            "num_return_sequences",
            "do_sample",
        ]
        params = {key: kwargs_with_defaults.get(key) for key in accepted_params if key in kwargs_with_defaults}
        generated_texts = self._post(data={"inputs": prompt, "parameters": params}, **kwargs)
        if stop_words:
            for idx, _ in enumerate(generated_texts):
                earliest_stop_word_idx = len(generated_texts[idx])
                for stop_word in stop_words:
                    stop_word_idx = generated_texts[idx].find(stop_word)
                    if stop_word_idx != -1:
                        earliest_stop_word_idx = min(earliest_stop_word_idx, stop_word_idx)
                generated_texts[idx] = generated_texts[idx][:earliest_stop_word_idx]

        return generated_texts
    def _post(
        self,
        data: Dict[str, Any],
        attempts: int = HF_RETRIES,
        status_codes: Optional[List[int]] = None,
        timeout: float = HF_TIMEOUT,
        **kwargs,
    ) -> List[str]:
        """
        Posts data to the HF inference model. It takes in a prompt and returns a list of responses using a REST invocation.
        :param data: The data to be sent to the model.
        :param attempts: The number of attempts to make.
        :param status_codes: The status codes to retry on.
        :param timeout: The timeout for the request.
        :return: The list of generated responses.
        """
        generated_texts: List[str] = []
        if status_codes is None:
            status_codes = [429]
        try:
            response = request_with_retry(
                method="POST",
                status_codes=status_codes,
                attempts=attempts,
                url=self.url,
                headers=self.headers,
                json=data,
                timeout=timeout,
            )
            output = json.loads(response.text)
            generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
        except requests.HTTPError as err:
            res = err.response
            if res.status_code == 429:
                raise HuggingFaceInferenceLimitError(f"API rate limit exceeded: {res.text}")
            if res.status_code == 401:
                raise HuggingFaceInferenceUnauthorizedError(f"API key is invalid: {res.text}")

            raise HuggingFaceInferenceError(
                f"HuggingFace Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
                status_code=res.status_code,
            )
        return generated_texts
    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        # TODO: new implementation incoming for all layers, let's omit this for now
        return prompt

    @staticmethod
    def is_inference_endpoint(model_name_or_path: str) -> bool:
        return model_name_or_path is not None and all(
            token in model_name_or_path for token in ["https://", "endpoints"]
        )

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        if cls.is_inference_endpoint(model_name_or_path):
            return True
        else:
            # Check if the model is an HF inference API
            task_name: Optional[str] = None
            is_inference_api = False
            try:
                task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
                is_inference_api = "api_key" in kwargs
            except RuntimeError:
                # This will fail for all non-HF models
                return False

            return is_inference_api and task_name in ["text2text-generation", "text-generation"]
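End to end, the new layer is selected automatically when a PromptNode is given a Hugging Face model name or an Inference Endpoint URL together with an api_key. A minimal usage sketch, assuming a Haystack install that includes this commit and a valid token (the token value below is a placeholder):

from haystack.nodes import PromptNode

# Hosted Inference API: a plain model name plus api_key should route to
# HFInferenceEndpointInvocationLayer rather than the local HF layer.
prompt_node = PromptNode(
    model_name_or_path="google/flan-t5-xxl",
    api_key="<HF_API_TOKEN>",  # placeholder, see https://huggingface.co/settings/tokens
    max_length=128,
)
print(prompt_node("What is the capital of Germany?"))

# A private Inference Endpoint URL should work the same way, for example
# model_name_or_path="https://<your-unique-deployment-id>.us-east-1.aws.endpoints.huggingface.cloud"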
@@ -340,7 +340,7 @@ def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
 
 @pytest.mark.unit
 def test_update_hash():
-    agent = Agent(prompt_node=mock.Mock(), prompt_template=mock.Mock())
+    agent = Agent(prompt_node=MockPromptNode(), prompt_template=mock.Mock())
     assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e"
     agent.add_tool(
         Tool(
@@ -363,6 +363,7 @@ class MockReader(BaseReader):
 class MockPromptNode(PromptNode):
     def __init__(self):
         self.default_prompt_template = None
+        self.model_name_or_path = ""
 
     def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
         return [""]