diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py
index a88567f7b..c47bfeedf 100644
--- a/haystack/nodes/prompt/invocation_layer/chatgpt.py
+++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py
@@ -1,8 +1,9 @@
-from typing import Optional, List, Dict, Union
 import logging
+from typing import Optional, List, Dict, Union, Any
 
-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
+from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
+from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
 
 logger = logging.getLogger(__name__)
 
@@ -55,6 +56,10 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
                 kwargs["n"] = top_k
                 kwargs["best_of"] = top_k
             kwargs_with_defaults.update(kwargs)
+
+        stream = (
+            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
+        )
         payload = {
             "model": self.model_name_or_path,
             "messages": messages,
@@ -62,15 +67,22 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
             "temperature": kwargs_with_defaults.get("temperature", 0.7),
             "top_p": kwargs_with_defaults.get("top_p", 1),
             "n": kwargs_with_defaults.get("n", 1),
-            "stream": False,  # no support for streaming
+            "stream": stream,
             "stop": kwargs_with_defaults.get("stop", None),
             "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
             "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
             "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
         }
-        response = openai_request(url=self.url, headers=self.headers, payload=payload)
-        _check_openai_finish_reason(result=response, payload=payload)
-        assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
 
         # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
         # We want to exclude it to be consistent with other invocation layers
@@ -82,6 +94,12 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
 
         return assistant_response
 
+    def _extract_token(self, event_data: Dict[str, Any]):
+        delta = event_data["choices"][0]["delta"]
+        if "content" in delta:
+            return delta["content"]
+        return None
+
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Make sure the length of the prompt and answer is within the max tokens limit of the model.
         If needed, truncate the prompt text so that it fits within the limit.
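For reference, a minimal sketch of how the streaming path added to ChatGPTInvocationLayer.invoke could be exercised. It assumes the TokenStreamingHandler interface from haystack.nodes.prompt.invocation_layer.handlers (a __call__(token_received, **kwargs) that returns the token) and the layer's api_key / model_name_or_path / prompt arguments; CollectingHandler and the prompt text are illustrative, not part of the change.

import os

from haystack.nodes.prompt.invocation_layer.chatgpt import ChatGPTInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler


class CollectingHandler(TokenStreamingHandler):
    """Prints each token as it arrives and keeps a copy for later inspection."""

    def __init__(self):
        self.tokens: list = []

    def __call__(self, token_received: str, **kwargs) -> str:
        print(token_received, end="", flush=True)
        self.tokens.append(token_received)
        return token_received  # returned tokens are joined into the final answer


layer = ChatGPTInvocationLayer(api_key=os.environ["OPENAI_API_KEY"], model_name_or_path="gpt-3.5-turbo")
# Supplying stream_handler alone is enough: the new `stream` flag is true whenever a handler is passed.
answers = layer.invoke(prompt="Write a haiku about server-sent events.", stream_handler=CollectingHandler())

Passing stream=True without a handler falls back to DefaultTokenStreamingHandler, which prints tokens to stdout.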
diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py
index 0e9bd422f..32d1564dc 100644
--- a/haystack/nodes/prompt/invocation_layer/open_ai.py
+++ b/haystack/nodes/prompt/invocation_layer/open_ai.py
@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Optional, cast
+from typing import List, Union, Dict, Optional, cast, Any
 import json
 import logging
 
@@ -151,19 +151,25 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
             response = openai_request(
                 url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
             )
             handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            client = sseclient.SSEClient(response)
-            tokens: List[str] = []
-            try:
-                for event in client.events():
-                    if event.data != TokenStreamingHandler.DONE_MARKER:
-                        ed = json.loads(event.data)
-                        token: str = ed["choices"][0]["text"]
-                        tokens.append(handler(token, event_data=ed["choices"]))
-            finally:
-                client.close()
-            return ["".join(tokens)]  # return a list of strings just like non-streaming
+            return self._process_streaming_response(response=response, stream_handler=handler)
+
+    def _process_streaming_response(self, response, stream_handler: TokenStreamingHandler):
+        client = sseclient.SSEClient(response)
+        tokens: List[str] = []
+        try:
+            for event in client.events():
+                if event.data != TokenStreamingHandler.DONE_MARKER:
+                    event_data = json.loads(event.data)
+                    token: str = self._extract_token(event_data)
+                    if token:
+                        tokens.append(stream_handler(token, event_data=event_data["choices"]))
+        finally:
+            client.close()
+        return ["".join(tokens)]  # return a list of strings just like non-streaming
+
+    def _extract_token(self, event_data: Dict[str, Any]):
+        return event_data["choices"][0]["text"]
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
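With this refactor the base class owns the SSE read loop in _process_streaming_response, and subclasses only override _extract_token, because the Completions and Chat Completions streams shape their events differently. A small sketch of the two event shapes, with the field paths taken from the code above and the remaining payload values made up:

import json

# Hypothetical SSE data frames; only the fields read by _extract_token matter here.
completion_frame = '{"choices": [{"text": " world", "index": 0, "finish_reason": null}]}'
chat_content_frame = '{"choices": [{"delta": {"content": " world"}, "index": 0, "finish_reason": null}]}'
chat_role_frame = '{"choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": null}]}'

# OpenAIInvocationLayer._extract_token: the token text sits at choices[0]["text"].
assert json.loads(completion_frame)["choices"][0]["text"] == " world"

# ChatGPTInvocationLayer._extract_token: the token text sits at choices[0]["delta"]["content"];
# delta frames without content (such as an initial role-only frame) yield None and are
# skipped by the `if token:` check in _process_streaming_response.
assert json.loads(chat_content_frame)["choices"][0]["delta"].get("content") == " world"
assert json.loads(chat_role_frame)["choices"][0]["delta"].get("content") is None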