OpenAI streaming support (#4397)
parent 3ecce5cbeb
commit f13501309e
@@ -1,7 +1,9 @@
+import json
 import logging
-from abc import abstractmethod
+from abc import abstractmethod, ABC
 from typing import Dict, List, Optional, Union, Type

+import sseclient
 import torch
 from transformers import (
     pipeline,
@@ -27,6 +29,38 @@ from haystack.utils.openai_utils import (
 logger = logging.getLogger(__name__)


+class TokenStreamingHandler(ABC):
+    """
+    TokenStreamingHandler implementations handle the streaming of tokens from the stream.
+    """
+
+    DONE_MARKER = "[DONE]"
+
+    @abstractmethod
+    def __call__(self, token_received: str, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        pass
+
+
+class DefaultTokenStreamingHandler(TokenStreamingHandler):
+    def __call__(self, token_received, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        print(token_received, flush=True, end="")
+        return token_received
+
+
 class PromptModelInvocationLayer:
     """
     PromptModelInvocationLayer implementations execute a prompt on an underlying model.
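Note: as a quick illustration of the handler contract introduced above, here is a minimal sketch of a custom handler that buffers tokens instead of printing them. The class name CollectingTokenStreamingHandler and its tokens attribute are illustrative only, not part of this commit; the import path mirrors the one used by the tests below.

from typing import List

from haystack.nodes.prompt.providers import TokenStreamingHandler


class CollectingTokenStreamingHandler(TokenStreamingHandler):
    """Illustrative handler that collects streamed tokens in memory instead of printing them."""

    def __init__(self):
        self.tokens: List[str] = []

    def __call__(self, token_received: str, **kwargs) -> str:
        # Keep the token and pass it through unchanged so the joined response stays intact.
        self.tokens.append(token_received)
        return token_received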
@@ -341,6 +375,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 "frequency_penalty",
                 "best_of",
                 "logit_bias",
+                "stream",
+                "stream_handler",
             ]
             if key in kwargs
         }
@@ -385,6 +421,11 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 kwargs["n"] = top_k
                 kwargs["best_of"] = top_k
             kwargs_with_defaults.update(kwargs)
+
+        # either stream is True (will use default handler) or stream_handler is provided
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "prompt": prompt,
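Note: to make the derived flag above concrete, a small sketch (values hypothetical; DefaultTokenStreamingHandler assumed importable from the same module as the handler classes above): supplying a handler alone is enough to switch streaming on.

from haystack.nodes.prompt.providers import DefaultTokenStreamingHandler

kwargs_with_defaults = {"stream_handler": DefaultTokenStreamingHandler()}  # hypothetical invocation kwargs
stream = (
    kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
)
assert stream is True  # a handler implies streaming even without stream=True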
@@ -393,7 +434,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "logprobs": kwargs_with_defaults.get("logprobs", None),
             "echo": kwargs_with_defaults.get("echo", False),
             "stop": kwargs_with_defaults.get("stop", None),
@@ -402,10 +443,28 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "best_of": kwargs_with_defaults.get("best_of", 1),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        res = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_text_completion_answers(result=res, payload=payload)
-        responses = [ans["text"].strip() for ans in res["choices"]]
-        return responses
+        if not stream:
+            res = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_text_completion_answers(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+            return responses
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            client = sseclient.SSEClient(response)
+            tokens: List[str] = []
+            try:
+                for event in client.events():
+                    if event.data != TokenStreamingHandler.DONE_MARKER:
+                        ed = json.loads(event.data)
+                        token: str = ed["choices"][0]["text"]
+                        tokens.append(handler(token, event_data=ed["choices"]))
+            finally:
+                client.close()
+            return ["".join(tokens)]  # return a list of strings just like non-streaming

     def _ensure_token_limit(self, prompt: str) -> str:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
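Note: the streaming branch above parses each SSE event's data field as JSON and extracts choices[0]["text"]. A standalone sketch of that single step; the sample payload shape is assumed from the parsing code, not copied from an actual API response.

import json

# Illustrative "data" payload of one streamed completion event.
event_data = '{"choices": [{"text": " Berlin", "index": 0, "logprobs": null, "finish_reason": null}]}'

ed = json.loads(event_data)
token = ed["choices"][0]["text"]
print(token, flush=True, end="")  # what DefaultTokenStreamingHandler does with each token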
@@ -4,7 +4,7 @@ import logging
 import platform
 import sys
 import json
-from typing import Dict, Union, Tuple
+from typing import Dict, Union, Tuple, Optional
 import requests

 from transformers import GPT2TokenizerFast
@@ -87,16 +87,25 @@ def _openai_text_completion_tokenization_details(model_name: str):
 @retry_with_exponential_backoff(
     backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
 )
-def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT):
+def openai_request(
+    url: str,
+    headers: Dict,
+    payload: Dict,
+    timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
+    read_response: Optional[bool] = True,
+    **kwargs,
+):
     """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.

     :param url: The URL of the OpenAI API.
     :param headers: Dictionary of HTTP Headers to send with the :class:`Request`.
     :param payload: The payload to send with the request.
     :param timeout: The timeout length of the request. The default is 30s.
+    :param read_response: Whether to read the response as JSON. The default is True.
     """
-    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout)
-    res = json.loads(response.text)
+    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs)
+    if read_response:
+        json_response = json.loads(response.text)

     if response.status_code != 200:
         openai_error: OpenAIError
@@ -110,8 +119,10 @@ def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float,
             status_code=response.status_code,
         )
         raise openai_error
-
-    return res
+    if read_response:
+        return json_response
+    else:
+        return response


 def _check_openai_text_completion_answers(result: Dict, payload: Dict) -> None:
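Note: with read_response=False the function now returns the raw requests.Response, and extra keyword arguments such as stream=True are forwarded to requests.request, which is what the streaming branch relies on. A minimal usage sketch; the url, headers, and payload values below are placeholders, not taken from this commit.

import sseclient

from haystack.utils.openai_utils import openai_request

url = "https://api.openai.com/v1/completions"  # placeholder endpoint
headers = {"Authorization": "Bearer <your-api-key>", "Content-Type": "application/json"}
payload = {"model": "text-davinci-003", "prompt": "Hello", "max_tokens": 16, "stream": True}

# Default behaviour (read_response=True): the parsed JSON dict is returned.
# res = openai_request(url=url, headers=headers, payload=payload)

# Streaming: get the raw response and hand it to sseclient, as the invocation layer does.
response = openai_request(url=url, headers=headers, payload=payload, read_response=False, stream=True)
client = sseclient.SSEClient(response)
try:
    for event in client.events():
        if event.data == "[DONE]":
            break
        print(event.data)
finally:
    client.close()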
@@ -67,6 +67,7 @@ dependencies = [
     # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
     "huggingface-hub>=0.5.0",
     "tenacity",  # retry decorator
+    "sseclient-py",  # server side events for OpenAI streaming

     # Preprocessing
     "more_itertools",  # for windowing
@@ -9,7 +9,7 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt import PromptModelInvocationLayer
-from haystack.nodes.prompt.providers import HFLocalInvocationLayer
+from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler


 def skip_test_for_invalid_key(prompt_model):
@@ -17,6 +17,21 @@ def skip_test_for_invalid_key(prompt_model):
         pytest.skip("No API key found, skipping test")


+class TestTokenStreamingHandler(TokenStreamingHandler):
+    stream_handler_invoked = False
+
+    def __call__(self, token_received, *args, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the underlying model.
+        :return: The token to be sent to the stream.
+        """
+        self.stream_handler_invoked = True
+        return token_received
+
+
 class CustomInvocationLayer(PromptModelInvocationLayer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -361,6 +376,38 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]


+@pytest.mark.integration
+@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
+def test_streaming_prompt_node_with_params(prompt_model):
+    skip_test_for_invalid_key(prompt_model)
+
+    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
+    ttsh = TestTokenStreamingHandler()
+    node = PromptNode(prompt_model)
+    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
+)
+def test_streaming_prompt_node():
+    ttsh = TestTokenStreamingHandler()
+
+    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
+    node = PromptNode(
+        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
+    )
+    response = node("What are some of the best cities in the world to live?")
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
 def test_simple_pipeline(prompt_model):
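Note: mirroring the integration tests above, a minimal end-to-end sketch of the two ways to enable streaming. It assumes an OPENAI_API_KEY environment variable and the text-davinci-003 model used in the tests; DefaultTokenStreamingHandler is assumed importable from haystack.nodes.prompt.providers, alongside TokenStreamingHandler.

import os

from haystack.nodes.prompt import PromptNode
from haystack.nodes.prompt.providers import DefaultTokenStreamingHandler

# Per-call streaming: stream=True uses the default handler, which prints tokens as they arrive.
node = PromptNode("text-davinci-003", api_key=os.environ["OPENAI_API_KEY"])
answer = node("What are some of the best cities in the world to live?", stream=True)

# Streaming for every call: register a handler as a model kwarg.
streaming_node = PromptNode(
    "text-davinci-003",
    api_key=os.environ["OPENAI_API_KEY"],
    model_kwargs={"stream_handler": DefaultTokenStreamingHandler()},
)
answer = streaming_node("What are some of the best cities in the world to live?")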