mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-15 09:33:34 +00:00)

feat: Add chatgpt streaming (#4659)

parent: 1dd6158244
commit: 6a5acaa1e2

This commit adds token streaming support to the ChatGPT invocation layer and refactors the streaming logic in the OpenAI invocation layer into reusable helper methods so both layers share the same SSE loop.
haystack/nodes/prompt/invocation_layer/chatgpt.py

@@ -1,8 +1,9 @@
-from typing import Optional, List, Dict, Union
-
 import logging
 
-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
+from typing import Optional, List, Dict, Union, Any
+
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
+from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
 
 logger = logging.getLogger(__name__)
@@ -55,6 +56,10 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             kwargs["n"] = top_k
             kwargs["best_of"] = top_k
         kwargs_with_defaults.update(kwargs)
+
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "messages": messages,
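With this change, streaming is switched on either by an explicit stream=True or simply by passing a stream_handler. A minimal standalone sketch of how that flag resolves; the function name resolve_stream is made up for illustration and only reproduces the expression added above:

# Standalone reproduction of the `stream = (...)` expression above, not library code.
def resolve_stream(kwargs_with_defaults: dict) -> bool:
    return (
        kwargs_with_defaults.get("stream", False)
        or kwargs_with_defaults.get("stream_handler", None) is not None
    )

assert resolve_stream({}) is False                           # default: blocking request
assert resolve_stream({"stream": True}) is True              # explicit opt-in
assert resolve_stream({"stream_handler": object()}) is True  # a handler implies streaming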
@@ -62,15 +67,22 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "stop": kwargs_with_defaults.get("stop", None),
             "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        response = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_finish_reason(result=response, payload=payload)
-        assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
 
         # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
         # We want to exclude it to be consistent with other invocation layers
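In the non-streaming branch the whole ChatCompletion body is read at once and the answers are pulled out of choices[*].message.content. A small illustration with a made-up response body (not captured from the API), showing only the fields this code touches:

# Made-up ChatCompletion response body; values are invented for illustration.
response = {
    "choices": [
        {"message": {"role": "assistant", "content": " Berlin. "}, "finish_reason": "stop"},
    ]
}
# Same comprehension as in the non-streaming branch above.
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
print(assistant_response)  # ['Berlin.']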
@@ -82,6 +94,12 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
 
         return assistant_response
 
+    def _extract_token(self, event_data: Dict[str, Any]):
+        delta = event_data["choices"][0]["delta"]
+        if "content" in delta:
+            return delta["content"]
+        return None
+
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Make sure the length of the prompt and answer is within the max tokens limit of the model.
         If needed, truncate the prompt text so that it fits within the limit.
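Streamed chat chunks carry incremental delta objects instead of full message objects; the first chunk usually contains only the role and the final chunk an empty delta, so the override returns None for those and the caller's if-token guard skips them. An illustration with made-up chunk data:

# Made-up streamed chat chunks; field names follow the OpenAI chat delta format.
role_only_chunk = {"choices": [{"delta": {"role": "assistant"}, "finish_reason": None}]}
content_chunk = {"choices": [{"delta": {"content": "Hel"}, "finish_reason": None}]}

def extract_token(event_data):  # same logic as the new ChatGPTInvocationLayer._extract_token
    delta = event_data["choices"][0]["delta"]
    if "content" in delta:
        return delta["content"]
    return None

print(extract_token(role_only_chunk))  # None -> skipped by the `if token:` guard
print(extract_token(content_chunk))    # 'Hel' -> forwarded to the stream handler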
haystack/nodes/prompt/invocation_layer/open_ai.py

@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Optional, cast
+from typing import List, Union, Dict, Optional, cast, Any
 import json
 import logging
 
@@ -151,19 +151,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )
 
             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            client = sseclient.SSEClient(response)
-            tokens: List[str] = []
-            try:
-                for event in client.events():
-                    if event.data != TokenStreamingHandler.DONE_MARKER:
-                        ed = json.loads(event.data)
-                        token: str = ed["choices"][0]["text"]
-                        tokens.append(handler(token, event_data=ed["choices"]))
-            finally:
-                client.close()
-            return ["".join(tokens)]  # return a list of strings just like non-streaming
+            return self._process_streaming_response(response=response, stream_handler=handler)
+
+    def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
+        client = sseclient.SSEClient(response)
+        tokens: List[str] = []
+        try:
+            for event in client.events():
+                if event.data != TokenStreamingHandler.DONE_MARKER:
+                    event_data = json.loads(event.data)
+                    token: str = self._extract_token(event_data)
+                    if token:
+                        tokens.append(stream_handler(token, event_data=event_data["choices"]))
+        finally:
+            client.close()
+        return ["".join(tokens)]  # return a list of strings just like non-streaming
+
+    def _extract_token(self, event_data: Dict[str, Any]):
+        return event_data["choices"][0]["text"]
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
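Moving the SSE loop into _process_streaming_response and the per-event parsing into _extract_token is what lets ChatGPTInvocationLayer reuse the loop unchanged and only override the extraction step. Below is a minimal sketch of plugging in a custom handler, assuming the TokenStreamingHandler contract is "called once per token, returns the token" as DefaultTokenStreamingHandler suggests; the class name CollectingTokenStreamingHandler and the commented-out invoke() call are illustrative, not part of this commit:

from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler


class CollectingTokenStreamingHandler(TokenStreamingHandler):
    """Collects streamed tokens in memory instead of printing them."""

    def __init__(self):
        self.tokens = []

    def __call__(self, token_received: str, **kwargs) -> str:
        self.tokens.append(token_received)
        return token_received


# Hypothetical usage: passing a stream_handler switches invoke() onto the streaming
# path, because stream_handler is not None (see the `stream = (...)` change above).
# handler = CollectingTokenStreamingHandler()
# layer.invoke(prompt="Tell me a joke", stream_handler=handler)
# print("".join(handler.tokens))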