diff --git a/haystack/nodes/prompt/providers.py b/haystack/nodes/prompt/providers.py
index 315eb2680..d945da0c1 100644
--- a/haystack/nodes/prompt/providers.py
+++ b/haystack/nodes/prompt/providers.py
@@ -1,7 +1,9 @@
+import json
 import logging
-from abc import abstractmethod
+from abc import abstractmethod, ABC
 from typing import Dict, List, Optional, Union, Type
 
+import sseclient
 import torch
 from transformers import (
     pipeline,
@@ -27,6 +29,38 @@ from haystack.utils.openai_utils import (
 logger = logging.getLogger(__name__)
 
 
+class TokenStreamingHandler(ABC):
+    """
+    TokenStreamingHandler implementations handle the streaming of tokens from the stream.
+    """
+
+    DONE_MARKER = "[DONE]"
+
+    @abstractmethod
+    def __call__(self, token_received: str, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        pass
+
+
+class DefaultTokenStreamingHandler(TokenStreamingHandler):
+    def __call__(self, token_received, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        print(token_received, flush=True, end="")
+        return token_received
+
+
 class PromptModelInvocationLayer:
     """
     PromptModelInvocationLayer implementations execute a prompt on an underlying model.
@@ -341,6 +375,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 "frequency_penalty",
                 "best_of",
                 "logit_bias",
+                "stream",
+                "stream_handler",
             ]
             if key in kwargs
         }
@@ -385,6 +421,11 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             kwargs["n"] = top_k
             kwargs["best_of"] = top_k
         kwargs_with_defaults.update(kwargs)
+
+        # either stream is True (will use default handler) or stream_handler is provided
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "prompt": prompt,
@@ -393,7 +434,7 @@
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "logprobs": kwargs_with_defaults.get("logprobs", None),
             "echo": kwargs_with_defaults.get("echo", False),
             "stop": kwargs_with_defaults.get("stop", None),
@@ -402,10 +443,28 @@
             "best_of": kwargs_with_defaults.get("best_of", 1),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        res = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_text_completion_answers(result=res, payload=payload)
-        responses = [ans["text"].strip() for ans in res["choices"]]
-        return responses
+        if not stream:
+            res = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_text_completion_answers(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+            return responses
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            client = sseclient.SSEClient(response)
+            tokens: List[str] = []
+            try:
+                for event in client.events():
+                    if event.data != TokenStreamingHandler.DONE_MARKER:
+                        ed = json.loads(event.data)
+                        token: str = ed["choices"][0]["text"]
+                        tokens.append(handler(token, event_data=ed["choices"]))
+            finally:
+                client.close()
+            return ["".join(tokens)]  # return a list of strings just like non-streaming
 
     def _ensure_token_limit(self, prompt: str) -> str:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
diff --git a/haystack/utils/openai_utils.py b/haystack/utils/openai_utils.py
index 320352412..939f58e08 100644
--- a/haystack/utils/openai_utils.py
+++ b/haystack/utils/openai_utils.py
@@ -4,7 +4,7 @@ import logging
 import platform
 import sys
 import json
-from typing import Dict, Union, Tuple
+from typing import Dict, Union, Tuple, Optional
 
 import requests
 from transformers import GPT2TokenizerFast
@@ -87,16 +87,25 @@ def _openai_text_completion_tokenization_details(model_name: str):
 @retry_with_exponential_backoff(
     backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
 )
-def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT):
+def openai_request(
+    url: str,
+    headers: Dict,
+    payload: Dict,
+    timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
+    read_response: Optional[bool] = True,
+    **kwargs,
+):
     """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.
 
     :param url: The URL of the OpenAI API.
     :param headers: Dictionary of HTTP Headers to send with the :class:`Request`.
     :param payload: The payload to send with the request.
     :param timeout: The timeout length of the request. The default is 30s.
+    :param read_response: Whether to read the response as JSON. The default is True.
     """
-    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout)
-    res = json.loads(response.text)
+    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs)
+    if read_response:
+        json_response = json.loads(response.text)
 
     if response.status_code != 200:
         openai_error: OpenAIError
@@ -110,8 +119,10 @@ def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float,
                 status_code=response.status_code,
             )
         raise openai_error
-
-    return res
+    if read_response:
+        return json_response
+    else:
+        return response
 
 
 def _check_openai_text_completion_answers(result: Dict, payload: Dict) -> None:
diff --git a/pyproject.toml b/pyproject.toml
index 6e97501f6..ca757868c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,6 +67,7 @@ dependencies = [
     # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
     "huggingface-hub>=0.5.0",
     "tenacity",  # retry decorator
+    "sseclient-py",  # server side events for OpenAI streaming
 
     # Preprocessing
     "more_itertools",  # for windowing
diff --git a/test/nodes/test_prompt_node.py b/test/nodes/test_prompt_node.py
index 1368034bb..b23ec9cd7 100644
--- a/test/nodes/test_prompt_node.py
+++ b/test/nodes/test_prompt_node.py
@@ -9,7 +9,7 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt import PromptModelInvocationLayer
-from haystack.nodes.prompt.providers import HFLocalInvocationLayer
+from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
 
 
 def skip_test_for_invalid_key(prompt_model):
@@ -17,6 +17,21 @@ def skip_test_for_invalid_key(prompt_model):
         pytest.skip("No API key found, skipping test")
 
 
+class TestTokenStreamingHandler(TokenStreamingHandler):
+    stream_handler_invoked = False
+
+    def __call__(self, token_received, *args, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the underlying model.
+        :return: The token to be sent to the stream.
+        """
+        self.stream_handler_invoked = True
+        return token_received
+
+
 class CustomInvocationLayer(PromptModelInvocationLayer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -361,6 +376,38 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
+def test_streaming_prompt_node_with_params(prompt_model):
+    skip_test_for_invalid_key(prompt_model)
+
+    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
+    ttsh = TestTokenStreamingHandler()
+    node = PromptNode(prompt_model)
+    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
+)
+def test_streaming_prompt_node():
+    ttsh = TestTokenStreamingHandler()
+
+    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
+    node = PromptNode(
+        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
+    )
+    response = node("What are some of the best cities in the world to live?")
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
 def test_simple_pipeline(prompt_model):
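
For reviewers, a minimal usage sketch of the streaming behavior this patch adds, modeled directly on the tests above. The `CollectingHandler` class, the prompt text, and the variable names are illustrative only; the model name, `api_key`, `model_kwargs`, `stream`, and `stream_handler` parameters are the ones exercised by the new tests.

```python
import os

from haystack.nodes.prompt import PromptNode
from haystack.nodes.prompt.providers import TokenStreamingHandler


class CollectingHandler(TokenStreamingHandler):
    """Illustrative handler: collects tokens as they arrive instead of printing them."""

    def __init__(self):
        self.tokens = []

    def __call__(self, token_received, **kwargs) -> str:
        self.tokens.append(token_received)
        return token_received


# Option 1: stream a single call; with no handler given, DefaultTokenStreamingHandler
# prints each token to stdout as it arrives.
node = PromptNode("text-davinci-003", api_key=os.environ["OPENAI_API_KEY"])
node("What are some of the best cities in the world to live?", stream=True)

# Option 2: register a custom handler for all calls via model_kwargs
# (a per-call handler can also be passed as stream_handler=..., as in the tests).
handler = CollectingHandler()
streaming_node = PromptNode(
    "text-davinci-003",
    api_key=os.environ["OPENAI_API_KEY"],
    model_kwargs={"stream_handler": handler},
)
streaming_node("What are some of the best cities in the world to live?")
print("".join(handler.tokens))
```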