feat: Add chatgpt streaming (#4659)

Vladimir Blagojevic 2023-04-14 16:02:28 +02:00 committed by GitHub
parent 1dd6158244
commit 6a5acaa1e2
2 changed files with 43 additions and 19 deletions
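
Streaming can then be switched on from user code either with the stream flag or by supplying a stream handler. A minimal usage sketch, assuming the Haystack v1 PromptNode API (the API key value is a placeholder):

from haystack.nodes import PromptNode

# Sketch only: model_kwargs are forwarded to the invocation layer, where either
# "stream": True or a "stream_handler" entry flips the new `stream` switch below.
prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key="YOUR_OPENAI_API_KEY",  # placeholder
    model_kwargs={"stream": True},
)
# With the DefaultTokenStreamingHandler, tokens are printed as they arrive.
prompt_node("What is the capital of France?")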

haystack/nodes/prompt/invocation_layer/chat_gpt.py

@@ -1,8 +1,9 @@
-from typing import Optional, List, Dict, Union
 import logging
+from typing import Optional, List, Dict, Union, Any

-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
 from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
+from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages

 logger = logging.getLogger(__name__)
@@ -55,6 +56,10 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             kwargs["n"] = top_k
             kwargs["best_of"] = top_k
         kwargs_with_defaults.update(kwargs)
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "messages": messages,
@@ -62,15 +67,22 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "stop": kwargs_with_defaults.get("stop", None),
             "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        response = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_finish_reason(result=response, payload=payload)
-        assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

         # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
         # We want to exclude it to be consistent with other invocation layers
@@ -82,6 +94,12 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
         return assistant_response

+    def _extract_token(self, event_data: Dict[str, Any]):
+        delta = event_data["choices"][0]["delta"]
+        if "content" in delta:
+            return delta["content"]
+        return None
+
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Make sure the length of the prompt and answer is within the max tokens limit of the model.
         If needed, truncate the prompt text so that it fits within the limit.
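
The _extract_token override above exists because the chat endpoint streams incremental delta objects rather than the plain text chunks of the completion endpoint. A self-contained sketch with hand-written sample events (illustrative payloads, not captured API output):

import json

# Chat events carry {"delta": {"content": ...}}; completion events carry {"text": ...}.
chat_event = json.loads('{"choices": [{"delta": {"content": "Hello"}}]}')
completion_event = json.loads('{"choices": [{"text": "Hello"}]}')

assert chat_event["choices"][0]["delta"].get("content") == "Hello"  # chat layer's extraction
assert completion_event["choices"][0]["text"] == "Hello"            # base layer's extraction

# The first chat event typically only announces the role and has no "content" key,
# which is why the chat override returns None and _process_streaming_response
# skips such events via `if token:`.
role_event = json.loads('{"choices": [{"delta": {"role": "assistant"}}]}')
assert "content" not in role_event["choices"][0]["delta"]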

haystack/nodes/prompt/invocation_layer/open_ai.py

@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Optional, cast
+from typing import List, Union, Dict, Optional, cast, Any

 import json
 import logging
@@ -151,19 +151,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )
             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            client = sseclient.SSEClient(response)
-            tokens: List[str] = []
-            try:
-                for event in client.events():
-                    if event.data != TokenStreamingHandler.DONE_MARKER:
-                        ed = json.loads(event.data)
-                        token: str = ed["choices"][0]["text"]
-                        tokens.append(handler(token, event_data=ed["choices"]))
-            finally:
-                client.close()
-            return ["".join(tokens)]  # return a list of strings just like non-streaming
+            return self._process_streaming_response(response=response, stream_handler=handler)
+
+    def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
+        client = sseclient.SSEClient(response)
+        tokens: List[str] = []
+        try:
+            for event in client.events():
+                if event.data != TokenStreamingHandler.DONE_MARKER:
+                    event_data = json.loads(event.data)
+                    token: str = self._extract_token(event_data)
+                    if token:
+                        tokens.append(stream_handler(token, event_data=event_data["choices"]))
+        finally:
+            client.close()
+        return ["".join(tokens)]  # return a list of strings just like non-streaming
+
+    def _extract_token(self, event_data: Dict[str, Any]):
+        return event_data["choices"][0]["text"]

     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.