OpenAI streaming support (#4397)
parent 3ecce5cbeb
commit f13501309e
@@ -1,7 +1,9 @@
+import json
 import logging
-from abc import abstractmethod
+from abc import abstractmethod, ABC
 from typing import Dict, List, Optional, Union, Type

+import sseclient
 import torch
 from transformers import (
     pipeline,
@@ -27,6 +29,38 @@ from haystack.utils.openai_utils import (
 logger = logging.getLogger(__name__)


+class TokenStreamingHandler(ABC):
+    """
+    TokenStreamingHandler implementations handle the streaming of tokens from the stream.
+    """
+
+    DONE_MARKER = "[DONE]"
+
+    @abstractmethod
+    def __call__(self, token_received: str, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        pass
+
+
+class DefaultTokenStreamingHandler(TokenStreamingHandler):
+    def __call__(self, token_received, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the handler.
+        :return: The token to be sent to the stream.
+        """
+        print(token_received, flush=True, end="")
+        return token_received
+
+
 class PromptModelInvocationLayer:
     """
     PromptModelInvocationLayer implementations execute a prompt on an underlying model.
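The two classes above define the streaming callback contract: the abstract TokenStreamingHandler receives each token and returns what should be forwarded, and DefaultTokenStreamingHandler simply prints tokens as they arrive. As a rough sketch (not part of this diff), a caller could supply a custom handler that buffers tokens instead of printing them; only the interface introduced above and the import path used later in the test changes are assumed.

# Hypothetical custom handler (illustration only): buffers tokens instead of printing them.
from typing import List

from haystack.nodes.prompt.providers import TokenStreamingHandler


class BufferingTokenStreamingHandler(TokenStreamingHandler):
    def __init__(self) -> None:
        self.buffer: List[str] = []

    def __call__(self, token_received: str, **kwargs) -> str:
        # Whatever is returned here is what the caller joins into the final response.
        self.buffer.append(token_received)
        return token_received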
@@ -341,6 +375,8 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
                 "frequency_penalty",
                 "best_of",
                 "logit_bias",
+                "stream",
+                "stream_handler",
             ]
             if key in kwargs
         }
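This hunk whitelists "stream" and "stream_handler" so they survive the kwargs filtering done by the OpenAI invocation layer. A minimal sketch of how such a whitelist behaves; the surrounding keys are abbreviated and purely illustrative:

# Illustration only: unknown kwargs are dropped, the new streaming kwargs are kept.
accepted_keys = ["temperature", "best_of", "logit_bias", "stream", "stream_handler"]  # abbreviated
kwargs = {"temperature": 0.5, "stream": True, "unsupported_option": 1}
model_input_kwargs = {key: kwargs[key] for key in accepted_keys if key in kwargs}
print(model_input_kwargs)  # {'temperature': 0.5, 'stream': True}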
@@ -385,6 +421,11 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             kwargs["n"] = top_k
             kwargs["best_of"] = top_k
         kwargs_with_defaults.update(kwargs)
+
+        # either stream is True (will use default handler) or stream_handler is provided
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "prompt": prompt,
@@ -393,7 +434,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "logprobs": kwargs_with_defaults.get("logprobs", None),
             "echo": kwargs_with_defaults.get("echo", False),
             "stop": kwargs_with_defaults.get("stop", None),
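Taken together, the two hunks above derive a single stream flag (True when stream=True is passed or when a stream_handler is supplied) and forward it in the request payload instead of the previous hard-coded False. A small sketch of that resolution; the model name and prompt are placeholders:

# Sketch of the flag resolution added above; model name and prompt are placeholders.
kwargs_with_defaults = {"stream_handler": object()}  # handler supplied, "stream" not set explicitly
stream = kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
payload = {"model": "text-davinci-003", "prompt": "Hello", "stream": stream}
assert payload["stream"] is True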
@@ -402,10 +443,28 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             "best_of": kwargs_with_defaults.get("best_of", 1),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        res = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_text_completion_answers(result=res, payload=payload)
-        responses = [ans["text"].strip() for ans in res["choices"]]
-        return responses
+        if not stream:
+            res = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_text_completion_answers(result=res, payload=payload)
+            responses = [ans["text"].strip() for ans in res["choices"]]
+            return responses
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            client = sseclient.SSEClient(response)
+            tokens: List[str] = []
+            try:
+                for event in client.events():
+                    if event.data != TokenStreamingHandler.DONE_MARKER:
+                        ed = json.loads(event.data)
+                        token: str = ed["choices"][0]["text"]
+                        tokens.append(handler(token, event_data=ed["choices"]))
+            finally:
+                client.close()
+            return ["".join(tokens)]  # return a list of strings just like non-streaming

     def _ensure_token_limit(self, prompt: str) -> str:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
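In the streaming branch, the raw HTTP response is wrapped in an SSE client, each data event is parsed as JSON, the token text is taken from choices[0]["text"], and the handler sees every token before the pieces are joined, so callers still receive a single-element list. The following standalone sketch replays that reassembly with hard-coded event payloads (no network or sseclient required); the event shapes mirror what the code above expects from the OpenAI completions stream:

# Standalone sketch of the token reassembly performed in the streaming branch.
import json

DONE_MARKER = "[DONE]"
events = [
    '{"choices": [{"text": "Hello", "index": 0}]}',
    '{"choices": [{"text": " world", "index": 0}]}',
    DONE_MARKER,
]

tokens = []
for data in events:
    if data != DONE_MARKER:
        choices = json.loads(data)["choices"]
        tokens.append(choices[0]["text"])  # a TokenStreamingHandler would be called here

print(["".join(tokens)])  # ['Hello world'], same shape as the non-streaming return value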
@@ -4,7 +4,7 @@ import logging
 import platform
 import sys
 import json
-from typing import Dict, Union, Tuple
+from typing import Dict, Union, Tuple, Optional
 import requests

 from transformers import GPT2TokenizerFast
@@ -87,16 +87,25 @@ def _openai_text_completion_tokenization_details(model_name: str):
 @retry_with_exponential_backoff(
     backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
 )
-def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT):
+def openai_request(
+    url: str,
+    headers: Dict,
+    payload: Dict,
+    timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
+    read_response: Optional[bool] = True,
+    **kwargs,
+):
     """Make a request to the OpenAI API given a `url`, `headers`, `payload`, and `timeout`.

     :param url: The URL of the OpenAI API.
     :param headers: Dictionary of HTTP Headers to send with the :class:`Request`.
     :param payload: The payload to send with the request.
     :param timeout: The timeout length of the request. The default is 30s.
+    :param read_response: Whether to read the response as JSON. The default is True.
     """
-    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout)
-    res = json.loads(response.text)
+    response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs)
+    if read_response:
+        json_response = json.loads(response.text)

     if response.status_code != 200:
         openai_error: OpenAIError
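With read_response and **kwargs added, openai_request can either parse the JSON body (the default, as before) or return the raw requests response so the caller can stream it. A hedged sketch of both call styles follows; the URL, headers, and payload values are placeholders, and the import path matches the module named in the first hunk of this diff:

# Sketch of the two call styles enabled by this change; url/headers/payload are placeholders.
from haystack.utils.openai_utils import openai_request

url = "https://api.openai.com/v1/completions"
headers = {"Authorization": "Bearer <API_KEY>", "Content-Type": "application/json"}
payload = {"model": "text-davinci-003", "prompt": "Hello", "stream": False}

# Default behaviour: the JSON body is parsed and returned as a dict.
res = openai_request(url=url, headers=headers, payload=payload)

# Streaming: keep the connection open and hand back the raw requests.Response
# so sseclient can iterate over the server-sent events.
payload["stream"] = True
response = openai_request(url=url, headers=headers, payload=payload, read_response=False, stream=True)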
@@ -110,8 +119,10 @@ def openai_request(url: str, headers: Dict, payload: Dict, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT):
                 status_code=response.status_code,
             )
         raise openai_error
-
-    return res
+    if read_response:
+        return json_response
+    else:
+        return response


 def _check_openai_text_completion_answers(result: Dict, payload: Dict) -> None:
@@ -67,6 +67,7 @@ dependencies = [
     # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
     "huggingface-hub>=0.5.0",
     "tenacity",  # retry decorator
+    "sseclient-py",  # server side events for OpenAI streaming

     # Preprocessing
     "more_itertools",  # for windowing
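The new sseclient-py dependency provides the SSEClient wrapper used in the streaming branch; it iterates over server-sent events read from a requests response opened with stream=True. A minimal, generic sketch of that pattern (the endpoint is a placeholder):

# Generic sseclient-py pattern, as used by the streaming branch; the endpoint is a placeholder.
import requests
import sseclient

response = requests.post("https://example.com/sse-endpoint", json={}, stream=True)
client = sseclient.SSEClient(response)
try:
    for event in client.events():
        print(event.data)
finally:
    client.close()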
@@ -9,7 +9,7 @@ from haystack import Document, Pipeline, BaseComponent, MultiLabel
 from haystack.errors import OpenAIError
 from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
 from haystack.nodes.prompt import PromptModelInvocationLayer
-from haystack.nodes.prompt.providers import HFLocalInvocationLayer
+from haystack.nodes.prompt.providers import HFLocalInvocationLayer, TokenStreamingHandler


 def skip_test_for_invalid_key(prompt_model):
@@ -17,6 +17,21 @@ def skip_test_for_invalid_key(prompt_model):
         pytest.skip("No API key found, skipping test")


+class TestTokenStreamingHandler(TokenStreamingHandler):
+    stream_handler_invoked = False
+
+    def __call__(self, token_received, *args, **kwargs) -> str:
+        """
+        This callback method is called when a new token is received from the stream.
+
+        :param token_received: The token received from the stream.
+        :param kwargs: Additional keyword arguments passed to the underlying model.
+        :return: The token to be sent to the stream.
+        """
+        self.stream_handler_invoked = True
+        return token_received
+
+
 class CustomInvocationLayer(PromptModelInvocationLayer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -361,6 +376,38 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]


+@pytest.mark.integration
+@pytest.mark.parametrize("prompt_model", ["openai", "azure"], indirect=True)
+def test_streaming_prompt_node_with_params(prompt_model):
+    skip_test_for_invalid_key(prompt_model)
+
+    # test streaming of calls to OpenAI by passing a stream handler to the prompt method
+    ttsh = TestTokenStreamingHandler()
+    node = PromptNode(prompt_model)
+    response = node("What are some of the best cities in the world to live and why?", stream=True, stream_handler=ttsh)
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key.",
+)
+def test_streaming_prompt_node():
+    ttsh = TestTokenStreamingHandler()
+
+    # test streaming of all calls to OpenAI by registering a stream handler as a model kwarg
+    node = PromptNode(
+        "text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), model_kwargs={"stream_handler": ttsh}
+    )
+    response = node("What are some of the best cities in the world to live?")
+
+    assert len(response[0]) > 0, "Response should not be empty"
+    assert ttsh.stream_handler_invoked, "Stream handler should have been invoked"
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
 def test_simple_pipeline(prompt_model):
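The two tests above cover the two ways streaming can be enabled: per invocation, by passing stream=True and/or a stream_handler to the prompt call, or per model, by registering a handler once in model_kwargs so every call streams. A hedged usage sketch mirroring those tests follows; the model name and API key handling are taken from the tests, and the handler class is illustrative:

# Usage sketch mirroring the tests above; PrintingHandler is illustrative.
import os

from haystack.nodes.prompt import PromptNode
from haystack.nodes.prompt.providers import TokenStreamingHandler


class PrintingHandler(TokenStreamingHandler):
    def __call__(self, token_received, **kwargs) -> str:
        print(token_received, end="", flush=True)
        return token_received


# 1) Per call: request streaming for a single prompt invocation.
node = PromptNode("text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
node("Name three large cities.", stream=True, stream_handler=PrintingHandler())

# 2) Per model: register the handler once; every subsequent call streams.
node = PromptNode(
    "text-davinci-003",
    api_key=os.environ.get("OPENAI_API_KEY"),
    model_kwargs={"stream_handler": PrintingHandler()},
)
node("Name three large cities.")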