feat: Add chatgpt streaming (#4659)

Repository: https://github.com/deepset-ai/haystack.git
Parent: 1dd6158244
Commit: 6a5acaa1e2
@@ -1,8 +1,9 @@
-from typing import Optional, List, Dict, Union
 import logging
+from typing import Optional, List, Dict, Union, Any

-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
+from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages

 logger = logging.getLogger(__name__)
@@ -55,6 +56,10 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
                 kwargs["n"] = top_k
                 kwargs["best_of"] = top_k
             kwargs_with_defaults.update(kwargs)
+
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "messages": messages,
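Streaming is opt-in here: the new `stream` flag is true when either `stream=True` is passed or a `stream_handler` object is supplied. A hypothetical usage sketch follows (not part of this diff; it assumes the v1 `PromptNode` API forwards `model_kwargs` to the invocation layer, where the condition above reads them, so parameter spellings may differ):

    # Hypothetical usage sketch, not part of this commit; assumes PromptNode forwards
    # model_kwargs to the invocation layer kwargs read by the condition above.
    from haystack.nodes import PromptNode
    from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler

    # Option 1: turn streaming on explicitly; the default handler is expected to
    # print tokens as they arrive.
    node = PromptNode("gpt-3.5-turbo", api_key="OPENAI_API_KEY", model_kwargs={"stream": True})

    # Option 2: supplying any stream_handler also enables streaming, per the `or` in the condition.
    node = PromptNode(
        "gpt-3.5-turbo",
        api_key="OPENAI_API_KEY",
        model_kwargs={"stream_handler": DefaultTokenStreamingHandler()},
    )
    print(node("Explain server-sent events in one sentence."))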
@@ -62,15 +67,22 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "stop": kwargs_with_defaults.get("stop", None),
             "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        response = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_finish_reason(result=response, payload=payload)
-        assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

         # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
         # We want to exclude it to be consistent with other invocation layers
@@ -82,6 +94,12 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):

         return assistant_response

+    def _extract_token(self, event_data: Dict[str, Any]):
+        delta = event_data["choices"][0]["delta"]
+        if "content" in delta:
+            return delta["content"]
+        return None
+
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Make sure the length of the prompt and answer is within the max tokens limit of the model.
         If needed, truncate the prompt text so that it fits within the limit.
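The `_extract_token` override is needed because the chat endpoint streams incremental `delta` objects (the completions endpoint streams a `text` field instead, see the base-class hunks below), and the first and last chunks of a chat stream carry a role-only or empty delta with no `content`. A minimal standalone sketch of that logic; the event dicts and the `extract_token` helper name are illustrative, not taken from the diff or from recorded API output:

    # Illustrative chunks shaped like Chat Completions streaming events (example data only).
    events = [
        {"choices": [{"delta": {"role": "assistant"}}]},  # first chunk: role only, no content
        {"choices": [{"delta": {"content": "Hello"}}]},   # token chunk
        {"choices": [{"delta": {"content": " world"}}]},  # token chunk
        {"choices": [{"delta": {}}]},                     # final chunk before the [DONE] marker
    ]

    def extract_token(event_data: dict):
        # Same logic as the _extract_token override above: return delta["content"] if present.
        delta = event_data["choices"][0]["delta"]
        return delta.get("content")

    tokens = [t for t in (extract_token(e) for e in events) if t]
    assert "".join(tokens) == "Hello world"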
@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Optional, cast
+from typing import List, Union, Dict, Optional, cast, Any
 import json
 import logging
@@ -151,19 +151,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )

             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            client = sseclient.SSEClient(response)
-            tokens: List[str] = []
-            try:
-                for event in client.events():
-                    if event.data != TokenStreamingHandler.DONE_MARKER:
-                        ed = json.loads(event.data)
-                        token: str = ed["choices"][0]["text"]
-                        tokens.append(handler(token, event_data=ed["choices"]))
-            finally:
-                client.close()
-            return ["".join(tokens)]  # return a list of strings just like non-streaming
+            return self._process_streaming_response(response=response, stream_handler=handler)
+
+    def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
+        client = sseclient.SSEClient(response)
+        tokens: List[str] = []
+        try:
+            for event in client.events():
+                if event.data != TokenStreamingHandler.DONE_MARKER:
+                    event_data = json.loads(event.data)
+                    token: str = self._extract_token(event_data)
+                    if token:
+                        tokens.append(stream_handler(token, event_data=event_data["choices"]))
+        finally:
+            client.close()
+        return ["".join(tokens)]  # return a list of strings just like non-streaming
+
+    def _extract_token(self, event_data: Dict[str, Any]):
+        return event_data["choices"][0]["text"]

     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
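`_process_streaming_response` pushes every non-`[DONE]` event through the handler and joins whatever the handler returns, so custom streaming behaviour only needs a `__call__` override. A minimal sketch, assuming `TokenStreamingHandler` defines the `__call__(token_received, **kwargs) -> str` contract implied by the call `stream_handler(token, event_data=...)` above:

    # Minimal custom handler sketch; the class name and the assumed __call__ contract
    # are illustrative, not part of this commit.
    from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler

    class CollectingTokenStreamingHandler(TokenStreamingHandler):
        """Collects streamed tokens in memory instead of printing them."""

        def __init__(self):
            self.tokens = []

        def __call__(self, token_received: str, **kwargs) -> str:
            self.tokens.append(token_received)
            # The return value is what _process_streaming_response accumulates into
            # the final ["".join(tokens)] result, so hand the token back unchanged.
            return token_received

    # Passing it as stream_handler both enables streaming and routes tokens here
    # (see the `stream = ...` condition in the ChatGPT hunk above).
    handler = CollectingTokenStreamingHandler()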