OpenAI streaming support (#4397)

Vladimir Blagojevic 2023-03-15 18:24:47 +01:00 committed by GitHub
parent 3ecce5cbeb
commit f13501309e
4 changed files with 131 additions and 13 deletions

haystack/nodes/prompt/providers.py

@@ -1,7 +1,9 @@
+import json
 import logging
-from abc import abstractmethod
+from abc import abstractmethod, ABC
 from typing import Dict, List, Optional, Union, Type
+import sseclient
 import torch
 from transformers import (
     pipeline,
@@ -27,6 +29,38 @@ from haystack.utils.openai_utils import (
 logger = logging.getLogger(__name__)


+class TokenStreamingHandler(ABC):
+    """
+    TokenStreamingHandler implementations handle the streaming of tokens from the stream.
+    """
+
+    DONE_MARKER = "[DONE]"
+
+    @abstractmethod
+    def __call__(self, token_received: str, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        pass
+
+
+class DefaultTokenStreamingHandler(TokenStreamingHandler):
+    def __call__(self, token_received, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        print(token_received, flush=True, end="")
+        return token_received
+
+
 class PromptModelInvocationLayer:
     """
     PromptModelInvocationLayer implementations execute a prompt on an underlying model.
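
A minimal sketch of a custom handler built on the TokenStreamingHandler interface added above (the class name and buffering behavior are illustrative, not part of this commit):

    # Hypothetical example: buffer streamed tokens instead of printing them.
    from typing import List

    class CollectingTokenStreamingHandler(TokenStreamingHandler):
        def __init__(self) -> None:
            self.tokens: List[str] = []

        def __call__(self, token_received: str, **kwargs) -> str:
            # Keep each streamed token; the returned value is what the invocation
            # layer joins into the final answer.
            self.tokens.append(token_received)
            return token_received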
@@ -341,6 +375,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 "frequency_penalty",
                 "best_of",
                 "logit_bias",
+                "stream",
+                "stream_handler",
             ]
             if key in kwargs
         }
@@ -385,6 +421,11 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 kwargs["n"] = top_k
                 kwargs["best_of"] = top_k
             kwargs_with_defaults.update(kwargs)
+
+        # either stream is True (will use default handler) or stream_handler is provided
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "prompt": prompt,
@@ -393,7 +434,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "logprobs": kwargs_with_defaults.get("logprobs", None),
             "echo": kwargs_with_defaults.get("echo", False),
             "stop": kwargs_with_defaults.get("stop", None),
@@ -402,10 +443,28 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "best_of": kwargs_with_defaults.get("best_of", 1),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        res = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_text_completion_answers(result=res, payload=payload)
-        responses = [ans["text"].strip() for ans in res["choices"]]
-        return responses
+        if not stream:
+            res = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_text_completion_answers(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+            return responses
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            client = sseclient.SSEClient(response)
+            tokens: List[str] = []
+            try:
+                for event in client.events():
+                    if event.data != TokenStreamingHandler.DONE_MARKER:
+                        ed = json.loads(event.data)
+                        token: str = ed["choices"][0]["text"]
+                        tokens.append(handler(token, event_data=ed["choices"]))
+            finally:
+                client.close()
+            return ["".join(tokens)]  # return a list of strings just like non-streaming

     def _ensure_token_limit(self, prompt: str) -> str:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.

haystack/utils/openai_utils.py

@@ -4,7 +4,7 @@ import logging
 import platform
 import sys
 import json
-from typing import Dict, Union, Tuple
+from typing import Dict, Union, Tuple, Optional

 import requests
 from transformers import GPT2TokenizerFast
@@ -87,16 +87,25 @@ def _openai_text_completion_tokenization_details(model_name: str):
 @retry_with_exponential_backoff(
     backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
 )
-def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT):
+def openai_request(
+    url: str,
+    headers: Dict,
+    payload: Dict,
+    timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
+    read_response: Optional[bool] = True,
+    **kwargs,
+):
     """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.

     :param url: The URL of the OpenAI API.
     :param headers: Dictionary of HTTP Headers to send with the :class:`Request`.
     :param payload: The payload to send with the request.
     :param timeout: The timeout length of the request. The default is 30s.
+    :param read_response: Whether to read the response as JSON. The default is True.
     """
-    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout)
-    res = json.loads(response.text)
+    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs)
+    if read_response:
+        json_response = json.loads(response.text)

     if response.status_code != 200:
         openai_error: OpenAIError
@@ -110,8 +119,10 @@ def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float,
                 status_code=response.status_code,
             )
         raise openai_error
-    return res
+    if read_response:
+        return json_response
+    else:
+        return response


 def _check_openai_text_completion_answers(result: Dict, payload: Dict) -> None:
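
A sketch of the new read_response flag in use (illustrative; the endpoint, API key, and payload values are placeholders): with read_response=False the raw requests.Response is returned so the caller can consume it as a server-sent-events stream, and extra kwargs such as stream=True are forwarded to requests.request.

    raw = openai_request(
        url="https://api.openai.com/v1/completions",
        headers={"Authorization": "Bearer YOUR_OPENAI_API_KEY", "Content-Type": "application/json"},
        payload={"model": "text-davinci-003", "prompt": "Hello", "stream": True, "max_tokens": 16},
        read_response=False,
        stream=True,
    )
    client = sseclient.SSEClient(raw)  # then iterate client.events() as providers.py does above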

pyproject.toml

@@ -67,6 +67,7 @@ dependencies = [
     # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
     "huggingface-hub>=0.5.0",
     "tenacity",  # retry decorator
+    "sseclient-py",  # server side events for OpenAI streaming
     # Preprocessing
     "more_itertools",  # for windowing

test/nodes/test_prompt_node.py

@@ -9,7 +9,7 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt import PromptModelInvocationLayer
-from haystack.nodes.prompt.providers import HFLocalInvocationLayer
+from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler


 def skip_test_for_invalid_key(prompt_model):
@@ -17,6 +17,21 @@ def skip_test_for_invalid_key(prompt_model):
         pytest.skip("No API key found, skipping test")


+class TestTokenStreamingHandler(TokenStreamingHandler):
+    stream_handler_invoked = False
+
+    def __call__(self, token_received, *args, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the underlying model.
+        :return: The token to be sent to the stream.
+        """
+        self.stream_handler_invoked = True
+        return token_received
+
+
 class CustomInvocationLayer(PromptModelInvocationLayer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -361,6 +376,38 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]


+@pytest.mark.integration
+@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
+def test_streaming_prompt_node_with_params(prompt_model):
+    skip_test_for_invalid_key(prompt_model)
+
+    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
+    ttsh = TestTokenStreamingHandler()
+    node = PromptNode(prompt_model)
+    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
+)
+def test_streaming_prompt_node():
+    ttsh = TestTokenStreamingHandler()
+
+    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
+    node = PromptNode(
+        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
+    )
+    response = node("What are some of the best cities in the world to live?")
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
 def test_simple_pipeline(prompt_model):