OpenAI streaming support (#4397)

Vladimir Blagojevic 2023-03-15 18:24:47 +01:00 committed by GitHub
parent 3ecce5cbeb
commit f13501309e
4 changed files with 131 additions and 13 deletions

haystack/nodes/prompt/providers.py

@@ -1,7 +1,9 @@
import json
import logging
from abc import abstractmethod
from abc import abstractmethod, ABC
from typing import Dict, List, Optional, Union, Type
import sseclient
import torch
from transformers import (
pipeline,
@@ -27,6 +29,38 @@ from haystack.utils.openai_utils import (
logger = logging.getLogger(__name__)
class TokenStreamingHandler(ABC):
"""
TokenStreamingHandler implementations handle tokens as they are received from a streaming model response.
"""
DONE_MARKER = "[DONE]"
@abstractmethod
def __call__(self, token_received: str, **kwargs) -> str:
"""
This callback method is called when a new token is received from the stream.
:param token_received: The token received from the stream.
:param kwargs: Additional keyword arguments passed to the handler.
:return: The token to be sent to the stream.
"""
pass
class DefaultTokenStreamingHandler(TokenStreamingHandler):
def __call__(self, token_received, **kwargs) -> str:
"""
This callback method is called when a new token is received from the stream.
:param token_received: The token received from the stream.
:param kwargs: Additional keyword arguments passed to the handler.
:return: The token to be sent to the stream.
"""
print(token_received, flush=True, end="")
return token_received
class PromptModelInvocationLayer:
"""
PromptModelInvocationLayer implementations execute a prompt on an underlying model.
@@ -341,6 +375,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
"frequency_penalty",
"best_of",
"logit_bias",
"stream",
"stream_handler",
]
if key in kwargs
}
@@ -385,6 +421,11 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
kwargs["n"] = top_k
kwargs["best_of"] = top_k
kwargs_with_defaults.update(kwargs)
# either stream is True (will use default handler) or stream_handler is provided
stream = (
kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
)
payload = {
"model": self.model_name_or_path,
"prompt": prompt,
@@ -393,7 +434,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
"temperature": kwargs_with_defaults.get("temperature", 0.7),
"top_p": kwargs_with_defaults.get("top_p", 1),
"n": kwargs_with_defaults.get("n", 1),
"stream": False, # no support for streaming
"stream": stream,
"logprobs": kwargs_with_defaults.get("logprobs", None),
"echo": kwargs_with_defaults.get("echo", False),
"stop": kwargs_with_defaults.get("stop", None),
@@ -402,10 +443,28 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
"best_of": kwargs_with_defaults.get("best_of", 1),
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
}
res = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_text_completion_answers(result=res, payload=payload)
responses = [ans["text"].strip() for ans in res["choices"]]
return responses
if not stream:
res = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_text_completion_answers(result=res, payload=payload)
responses = [ans["text"].strip() for ans in res["choices"]]
return responses
else:
response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
client = sseclient.SSEClient(response)
tokens: List[str] = []
try:
for event in client.events():
if event.data != TokenStreamingHandler.DONE_MARKER:
ed = json.loads(event.data)
token: str = ed["choices"][0]["text"]
tokens.append(handler(token, event_data=ed["choices"]))
finally:
client.close()
return ["".join(tokens)] # return a list of strings just like non-streaming
def _ensure_token_limit(self, prompt: str) -> str:
"""Ensure that the length of the prompt and answer is within the max tokens limit of the model.

haystack/utils/openai_utils.py

@@ -4,7 +4,7 @@ import logging
import platform
import sys
import json
from typing import Dict, Union, Tuple
from typing import Dict, Union, Tuple, Optional
import requests
from transformers import GPT2TokenizerFast
@@ -87,16 +87,25 @@ def _openai_text_completion_tokenization_details(model_name: str):
@retry_with_exponential_backoff(
backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
)
def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT):
def openai_request(
url: str,
headers: Dict,
payload: Dict,
timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
read_response: Optional[bool] = True,
**kwargs,
):
"""Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.
:param url: The URL of the OpenAI API.
:param headers: Dictionary of HTTP Headers to send with the :class:`Request`.
:param payload: The payload to send with the request.
:param timeout: The timeout length of the request. The default is 30s.
:param read_response: Whether to parse the response as JSON and return the parsed result. If False, the raw response object is returned instead. The default is True.
"""
response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout)
res = json.loads(response.text)
response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs)
if read_response:
json_response = json.loads(response.text)
if response.status_code != 200:
openai_error: OpenAIError
@@ -110,8 +119,10 @@ def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float,
status_code=response.status_code,
)
raise openai_error
return res
if read_response:
return json_response
else:
return response
def _check_openai_text_completion_answers(result: Dict, payload: Dict) -> None:
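
To sketch the two calling modes side by side (url, headers, and payload as built by the invocation layer above; illustrative only):

# Blocking mode: read_response defaults to True, so the parsed JSON is returned.
res = openai_request(url=url, headers=headers, payload=payload)

# Streaming mode: read_response=False returns the raw requests.Response;
# the extra stream=True kwarg is forwarded to requests.request().
raw = openai_request(url=url, headers=headers, payload=payload, read_response=False, stream=True)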

pyproject.toml

@@ -67,6 +67,7 @@ dependencies = [
# audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
"huggingface-hub>=0.5.0",
"tenacity", # retry decorator
"sseclient-py", # server side events for OpenAI streaming
# Preprocessing
"more_itertools", # for windowing

test/nodes/test_prompt_node.py

@@ -9,7 +9,7 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.errors import OpenAIError
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt import PromptModelInvocationLayer
from haystack.nodes.prompt.providers import HFLocalInvocationLayer
from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler
def skip_test_for_invalid_key(prompt_model):
@@ -17,6 +17,21 @@ def skip_test_for_invalid_key(prompt_model):
pytest.skip("No API key found, skipping test")
class TestTokenStreamingHandler(TokenStreamingHandler):
stream_handler_invoked = False
def __call__(self, token_received, *args, **kwargs) -> str:
"""
This callback method is called when a new token is received from the stream.
:param token_received: The token received from the stream.
:param kwargs: Additional keyword arguments passed to the handler.
:return: The token to be sent to the stream.
"""
self.stream_handler_invoked = True
return token_received
class CustomInvocationLayer(PromptModelInvocationLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -361,6 +376,38 @@ def test_stop_words(prompt_model):
assert "capital" in r[0] or "Germany" in r[0]
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
def test_streaming_prompt_node_with_params(prompt_model):
skip_test_for_invalid_key(prompt_model)
# test streaming of calls to OpenAI by passing a stream handler to the prompt method
ttsh = TestTokenStreamingHandler()
node = PromptNode(prompt_model)
response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)
assert len(response[0]) > 0, "Response should not be empty"
assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
)
def test_streaming_prompt_node():
ttsh = TestTokenStreamingHandler()
# test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
node = PromptNode(
"text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
)
response = node("What are some of the best cities in the world to live?")
assert len(response[0]) > 0, "Response should not be empty"
assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
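
For comparison, enabling streaming without registering a handler falls back to DefaultTokenStreamingHandler, which prints tokens to stdout as they arrive; a usage sketch (not one of the tests in this commit):

node = PromptNode("text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
response = node("What are some of the best cities in the world to live?", stream=True)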
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_simple_pipeline(prompt_model):