diff --git a/docs/pydoc/config/pipelines.yml b/docs/pydoc/config/pipelines.yml
index 479ac5cfc..b484631ff 100644
--- a/docs/pydoc/config/pipelines.yml
+++ b/docs/pydoc/config/pipelines.yml
@@ -24,4 +24,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: pipelines_api.md
-
diff --git a/docs/pydoc/config/prompt-node.yml b/docs/pydoc/config/prompt-node.yml
index 4358ddeb8..cac174d50 100644
--- a/docs/pydoc/config/prompt-node.yml
+++ b/docs/pydoc/config/prompt-node.yml
@@ -26,4 +26,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: prompt_node_api.md
-
diff --git a/docs/pydoc/config/whisper.yml b/docs/pydoc/config/whisper.yml
index 89018d493..0c7fbab52 100644
--- a/docs/pydoc/config/whisper.yml
+++ b/docs/pydoc/config/whisper.yml
@@ -24,4 +24,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: whisper_api.md
-
diff --git a/haystack/agents/base.py b/haystack/agents/base.py
index f03e6fc2a..33b181abb 100644
--- a/haystack/agents/base.py
+++ b/haystack/agents/base.py
@@ -11,7 +11,7 @@ from haystack import Pipeline, BaseComponent, Answer, Document
 from haystack.telemetry import send_event
 from haystack.agents.agent_step import AgentStep
 from haystack.agents.types import Color
-from haystack.agents.utils import print_text
+from haystack.agents.utils import print_text, STREAMING_CAPABLE_MODELS
 from haystack.errors import AgentError
 from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
 from haystack.nodes.prompt.invocation_layer import TokenStreamingHandler
@@ -180,7 +180,12 @@ class Agent:
         self.max_steps = max_steps
         self.tool_pattern = tool_pattern
         self.final_answer_pattern = final_answer_pattern
-        self.add_default_logging_callbacks()
+        # Resolve model name to check if it's a streaming model
+        if isinstance(self.prompt_node.model_name_or_path, str):
+            model_name = self.prompt_node.model_name_or_path
+        else:
+            model_name = self.prompt_node.model_name_or_path.model_name_or_path
+        self.add_default_logging_callbacks(streaming=any(m for m in STREAMING_CAPABLE_MODELS if m in model_name))
         self.hash = None
         self.last_hash = None
         self.update_hash()
@@ -197,7 +202,7 @@ class Agent:
             logger.debug("Telemetry exception: %s", str(exc))
             self.hash = "[an exception occurred during hashing]"

-    def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
+    def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN, streaming: bool = False) -> None:
         def on_tool_finish(
             tool_output: str,
             color: Optional[Color] = None,
@@ -215,7 +220,13 @@ class Agent:

         self.callback_manager.on_tool_finish += on_tool_finish
         self.callback_manager.on_agent_start += on_agent_start
-        self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
+
+        if streaming:
+            self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
+        else:
+            self.callback_manager.on_agent_step += lambda agent_step: print_text(
+                agent_step.prompt_node_response, color=agent_color
+            )

     def add_tool(self, tool: Tool):
         """
diff --git a/haystack/agents/utils.py b/haystack/agents/utils.py
index f86f9b58d..792aa533c 100644
--- a/haystack/agents/utils.py
+++ b/haystack/agents/utils.py
@@ -2,6 +2,8 @@ from typing import Optional

 from haystack.agents.types import Color


+STREAMING_CAPABLE_MODELS = ["davinci"]
+
 def print_text(text: str, end="", color: Optional[Color] = None) -> None:
     """
diff --git a/haystack/errors.py b/haystack/errors.py
index d9cf31f99..23064d6c4 100644
--- a/haystack/errors.py
+++ b/haystack/errors.py
@@ -179,3 +179,27 @@ class ImageToTextError(NodeError):

     def __init__(self, message: Optional[str] = None):
         super().__init__(message=message)
+
+
+class HuggingFaceInferenceError(NodeError):
+    """Exception for issues that occur in the HuggingFace inference node"""
+
+    def __init__(
+        self, message: Optional[str] = None, status_code: Optional[int] = None, send_message_in_event: bool = False
+    ):
+        super().__init__(message=message, send_message_in_event=send_message_in_event)
+        self.status_code = status_code
+
+
+class HuggingFaceInferenceLimitError(HuggingFaceInferenceError):
+    """Exception for issues that occur in the HuggingFace inference node due to rate limiting"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)
+
+
+class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):
+    """Exception for issues that occur in the HuggingFace inference node due to unauthorized access"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)
diff --git a/haystack/nodes/prompt/invocation_layer/__init__.py b/haystack/nodes/prompt/invocation_layer/__init__.py
index c49fb72df..f8753cc24 100644
--- a/haystack/nodes/prompt/invocation_layer/__init__.py
+++ b/haystack/nodes/prompt/invocation_layer/__init__.py
@@ -3,4 +3,5 @@ from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLay
 from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
+from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index e162b422d..bd7d2ed34 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -30,7 +30,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
     def __init__(
         self,
         model_name_or_path: str = "google/flan-t5-base",
-        max_length: Optional[int] = 100,
+        max_length: int = 100,
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = True,
         devices: Optional[List[Union[str, torch.device]]] = None,
@@ -247,8 +247,9 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         except RuntimeError:
             # This will fail for all non-HF models
             return False
-
-        return task_name in ["text2text-generation", "text-generation"]
+        # If an api_key is provided, the model could be an HF Inference Endpoint rather than a local model
+        using_api_key = kwargs.get("api_key", None) is not None
+        return not using_api_key and task_name in ["text2text-generation", "text-generation"]


 class StopWordsCriteria(StoppingCriteria):
diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py b/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py
new file mode 100644
index 000000000..dc20f4646
--- /dev/null
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py
@@ -0,0 +1,209 @@
+import json
+import os
+from typing import Optional, Dict, Union, List, Any
+import logging
+
+import requests
+from transformers.pipelines import get_task
+
+from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
+from haystack.errors import (
+    HuggingFaceInferenceLimitError,
+    HuggingFaceInferenceUnauthorizedError,
+    HuggingFaceInferenceError,
+)
+from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
+from haystack.utils.requests import request_with_retry
+
+logger = logging.getLogger(__name__)
+HF_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
+HF_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
+
+
+class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
+    """
+    A PromptModelInvocationLayer that invokes a remote Hugging Face Inference Endpoint or the hosted Inference API to prompt the model.
+    For more details, see the Hugging Face Inference API [documentation](https://huggingface.co/docs/api-inference/index)
+    and the Hugging Face Inference Endpoints [documentation](https://huggingface.co/inference-endpoints).
+
+    The Inference API is free to use but rate limited. If you need an inference solution for production, you can use
+    the Inference Endpoints service.
+
+    See the documentation for more details: https://huggingface.co/docs/inference-endpoints
+
+    """
+
+    def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
+        """
+        Creates an instance of HFInferenceEndpointInvocationLayer.
+        :param model_name_or_path: can be either:
+        a) a Hugging Face Inference model name (for example, google/flan-t5-xxl)
+        b) a Hugging Face Inference Endpoint URL (for example, https://.us-east-1.aws.endpoints.huggingface.cloud)
+        :param max_length: The maximum length of the output text.
+        :param api_key: The Hugging Face API token. You’ll need to provide your user token, which can
+        be found in your Hugging Face account [settings](https://huggingface.co/settings/tokens).
+        """
+        super().__init__(model_name_or_path)
+        valid_api_key = isinstance(api_key, str) and api_key
+        if not valid_api_key:
+            raise ValueError(
+                f"api_key {api_key} must be a valid Hugging Face token. "
+                f"Your token is available on your Hugging Face settings page."
+            )
+        valid_model_name_or_path = isinstance(model_name_or_path, str) and model_name_or_path
+        if not valid_model_name_or_path:
+            raise ValueError(
+                f"model_name_or_path {model_name_or_path} must be a valid Hugging Face model name or Inference Endpoint URL."
+            )
+        self.api_key = api_key
+        self.max_length = max_length
+
+        # See https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
+        # for a list of supported parameters
+        self.model_input_kwargs = {
+            key: kwargs[key]
+            for key in [
+                "top_k",
+                "top_p",
+                "temperature",
+                "repetition_penalty",
+                "max_new_tokens",
+                "max_time",
+                "return_full_text",
+                "num_return_sequences",
+                "do_sample",
+            ]
+            if key in kwargs
+        }
+
+    @property
+    def url(self) -> str:
+        if HFInferenceEndpointInvocationLayer.is_inference_endpoint(self.model_name_or_path):
+            # Inference Endpoint URL
+            # e.g. https://o3x2xh3o4m47mxny.us-east-1.aws.endpoints.huggingface.cloud
+            url = self.model_name_or_path
+
+        else:
+            url = f"https://api-inference.huggingface.co/models/{self.model_name_or_path}"
+        return url
+
+    @property
+    def headers(self) -> Dict[str, str]:
+        return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. It takes in a prompt and returns a list of responses using a REST invocation.
+        :return: A list of generated responses.
+        """
+        prompt = kwargs.get("prompt")
+        if not prompt:
+            raise ValueError(
+                f"No prompt provided. Model {self.model_name_or_path} requires a prompt. "
+                f"Make sure to provide the prompt in kwargs."
+            )
+        stop_words = kwargs.pop("stop_words", None)
+
+        kwargs_with_defaults = self.model_input_kwargs
+        if "max_new_tokens" not in kwargs_with_defaults:
+            kwargs_with_defaults["max_new_tokens"] = self.max_length
+
+        if "top_k" in kwargs:
+            top_k = kwargs.pop("top_k")
+            kwargs["num_return_sequences"] = top_k
+        kwargs_with_defaults.update(kwargs)
+        # see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
+        accepted_params = [
+            "top_p",
+            "top_k",
+            "temperature",
+            "repetition_penalty",
+            "max_new_tokens",
+            "max_time",
+            "return_full_text",
+            "num_return_sequences",
+            "do_sample",
+        ]
+        params = {key: kwargs_with_defaults.get(key) for key in accepted_params if key in kwargs_with_defaults}
+        generated_texts = self._post(data={"inputs": prompt, "parameters": params}, **kwargs)
+        if stop_words:
+            for idx, _ in enumerate(generated_texts):
+                earliest_stop_word_idx = len(generated_texts[idx])
+                for stop_word in stop_words:
+                    stop_word_idx = generated_texts[idx].find(stop_word)
+                    if stop_word_idx != -1:
+                        earliest_stop_word_idx = min(earliest_stop_word_idx, stop_word_idx)
+                generated_texts[idx] = generated_texts[idx][:earliest_stop_word_idx]
+
+        return generated_texts
+
+    def _post(
+        self,
+        data: Dict[str, Any],
+        attempts: int = HF_RETRIES,
+        status_codes: Optional[List[int]] = None,
+        timeout: float = HF_TIMEOUT,
+        **kwargs,
+    ) -> List[str]:
+        """
+        Posts data to the HF inference model or endpoint and returns a list of generated texts using a REST invocation.
+        :param data: The data to be sent to the model.
+        :param attempts: The number of attempts to make.
+        :param status_codes: The status codes to retry on.
+        :param timeout: The timeout for the request.
+        :return: A list of generated texts.
+        """
+        generated_texts: List[str] = []
+        if status_codes is None:
+            status_codes = [429]
+        try:
+            response = request_with_retry(
+                method="POST",
+                status_codes=status_codes,
+                attempts=attempts,
+                url=self.url,
+                headers=self.headers,
+                json=data,
+                timeout=timeout,
+            )
+            output = json.loads(response.text)
+            generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
+        except requests.HTTPError as err:
+            res = err.response
+            if res.status_code == 429:
+                raise HuggingFaceInferenceLimitError(f"API rate limit exceeded: {res.text}")
+            if res.status_code == 401:
+                raise HuggingFaceInferenceUnauthorizedError(f"API key is invalid: {res.text}")
+
+            raise HuggingFaceInferenceError(
+                f"HuggingFace Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
+                status_code=res.status_code,
+            )
+        return generated_texts
+
+    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
+        # TODO: new implementation incoming for all layers, let's omit this for now
+        return prompt
+
+    @staticmethod
+    def is_inference_endpoint(model_name_or_path: str) -> bool:
+        return model_name_or_path is not None and all(
+            token in model_name_or_path for token in ["https://", "endpoints"]
+        )
+
+    @classmethod
+    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
+        if cls.is_inference_endpoint(model_name_or_path):
+            return True
+        else:
+            # Check if the model is an HF inference API
+            task_name: Optional[str] = None
+            is_inference_api = False
+            try:
+                task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
+                is_inference_api = "api_key" in kwargs
+            except RuntimeError:
+                # This will fail for all non-HF models
+                return False
+
+            return is_inference_api and task_name in ["text2text-generation", "text-generation"]
diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py
index 0518add5f..79a4322c2 100644
--- a/test/agents/test_agent.py
+++ b/test/agents/test_agent.py
@@ -340,7 +340,7 @@ def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):

 @pytest.mark.unit
 def test_update_hash():
-    agent = Agent(prompt_node=mock.Mock(), prompt_template=mock.Mock())
+    agent = Agent(prompt_node=MockPromptNode(), prompt_template=mock.Mock())
     assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e"
     agent.add_tool(
         Tool(
diff --git a/test/conftest.py b/test/conftest.py
index 2f22e3df9..6245f5e50 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -363,6 +363,7 @@ class MockReader(BaseReader):
 class MockPromptNode(PromptNode):
     def __init__(self):
         self.default_prompt_template = None
+        self.model_name_or_path = ""

     def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
         return [""]
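
Usage sketch (not part of the patch): with the supports() changes above, passing an api_key to PromptNode should route model selection to the new HFInferenceEndpointInvocationLayer instead of HFLocalInvocationLayer. The model name, the token placeholder, and the generation parameters below are illustrative assumptions, not values taken from the patch.

```python
# Minimal sketch, assuming a valid Hugging Face token and a hosted text2text-generation
# model; none of these concrete values come from the patch itself.
from haystack.nodes import PromptNode

prompt_node = PromptNode(
    model_name_or_path="google/flan-t5-xxl",  # or an Inference Endpoint URL such as https://<name>.endpoints.huggingface.cloud
    api_key="<YOUR_HF_TOKEN>",
    max_length=128,
    # Forwarded to the invocation layer; both keys are in the layer's accepted parameter list.
    model_kwargs={"temperature": 0.7, "max_time": 20.0},
)

# PromptNode is callable and returns a list of generated strings.
print(prompt_node("What is the capital of Germany?"))
```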
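
A second sketch, also illustrative rather than taken from the patch: how the Agent's new streaming flag selects its default logging callback. Model names and key placeholders are assumptions; only models whose name contains an entry of STREAMING_CAPABLE_MODELS (currently just "davinci") keep token-by-token printing.

```python
# Minimal sketch of the Agent callback selection added above; model names and API key
# placeholders are assumptions for the example, not values from the patch.
from haystack.agents import Agent
from haystack.nodes import PromptNode

# "text-davinci-003" contains "davinci", so the Agent registers the on_new_token
# callback and prints tokens as they stream in.
streaming_agent = Agent(prompt_node=PromptNode("text-davinci-003", api_key="<OPENAI_API_KEY>"))

# A flan-t5 model name does not match STREAMING_CAPABLE_MODELS, so the Agent instead
# prints each full step response through the on_agent_step callback.
plain_agent = Agent(prompt_node=PromptNode("google/flan-t5-xxl", api_key="<YOUR_HF_TOKEN>"))
```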