Michael Feil 6ea8ae01a2
feat: Allow setting custom api_base for OpenAI nodes (#5033)
* add changes for api_base

* format retriever

* Update haystack/nodes/retriever/dense.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/audio/whisper_transcriber.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/preview/components/audio/whisper_remote.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/answer_generator/openai.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update test_retriever.py

* Update test_whisper_remote.py

* Update test_generator.py

* Update test_retriever.py

* reformat with black

* Update haystack/nodes/prompt/invocation_layer/chatgpt.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Add unit tests

* apply docstring suggestions

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Co-authored-by: michaelfeil <me@michaelfeil.eu>
Co-authored-by: Daria Fokina <daria.f93@gmail.com>
2023-06-05 11:32:06 +02:00

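The change this commit introduces shows up as the `api_base` parameter of the invocation layer below. A minimal usage sketch (the key is a placeholder and the base URL stands in for any OpenAI-compatible endpoint, e.g. a self-hosted proxy):

layer = ChatGPTInvocationLayer(
    api_key="YOUR_OPENAI_API_KEY",  # placeholder
    model_name_or_path="gpt-3.5-turbo",
    api_base="https://my-proxy.example.com/v1",  # custom base URL; defaults to https://api.openai.com/v1
)
# Requests now go to https://my-proxy.example.com/v1/chat/completions (see the `url` property in the file).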

import logging
from typing import Optional, List, Dict, Union, Any

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages

logger = logging.getLogger(__name__)

class ChatGPTInvocationLayer(OpenAIInvocationLayer):
    """
    ChatGPT Invocation Layer

    PromptModelInvocationLayer implementation for OpenAI's ChatGPT API (GPT-3.5 and GPT-4 models). Invocations are
    made using the REST API. See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details.

    Note: kwargs other than init parameter names are ignored to enable reflective construction of the class,
    as many variants of PromptModelInvocationLayer are possible and they may have different parameters.
    """

    def __init__(
        self,
        api_key: str,
        model_name_or_path: str = "gpt-3.5-turbo",
        max_length: Optional[int] = 500,
        api_base: str = "https://api.openai.com/v1",
        **kwargs,
    ):
        """
        Creates an instance of ChatGPTInvocationLayer for OpenAI's GPT-3.5 and GPT-4 models.

        :param model_name_or_path: The name or path of the underlying model.
        :param max_length: The maximum number of tokens the output text can have.
        :param api_key: The OpenAI API key.
        :param api_base: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param kwargs: Additional keyword arguments passed to the underlying model.
        [See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
        """
        super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)

    def invoke(self, *args, **kwargs):
        """
        Takes in either a prompt string or a list of ChatML messages and returns a list of responses,
        using a REST invocation.

        :return: A list of generated responses.

        Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
        For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
        """
        prompt = kwargs.get("prompt", None)

        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
            messages = prompt
        else:
            raise ValueError(
                f"The prompt format is different than what the model expects. "
                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
            )
        kwargs_with_defaults = self.model_input_kwargs
        if kwargs:
            # we use keyword stop_words but OpenAI uses stop
            if "stop_words" in kwargs:
                kwargs["stop"] = kwargs.pop("stop_words")
            # we use keyword top_k but OpenAI expects n (and best_of)
            if "top_k" in kwargs:
                top_k = kwargs.pop("top_k")
                kwargs["n"] = top_k
                kwargs["best_of"] = top_k
            kwargs_with_defaults.update(kwargs)
        stream = (
            kwargs_with_defaults.get("stream", False) or kwargs_with_defaults.get("stream_handler", None) is not None
        )

        payload = {
            "model": self.model_name_or_path,
            "messages": messages,
            "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
            "temperature": kwargs_with_defaults.get("temperature", 0.7),
            "top_p": kwargs_with_defaults.get("top_p", 1),
            "n": kwargs_with_defaults.get("n", 1),
            "stream": stream,
            "stop": kwargs_with_defaults.get("stop", None),
            "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0),
            "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0),
            "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
        }

        if not stream:
            response = openai_request(url=self.url, headers=self.headers, payload=payload)
            _check_openai_finish_reason(result=response, payload=payload)
            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
        else:
            response = openai_request(
                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
            )
            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word.
        # We want to exclude it to be consistent with other invocation layers.
        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
            stop_words = kwargs_with_defaults["stop"]
            for idx, _ in enumerate(assistant_response):
                for stop_word in stop_words:
                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
        return assistant_response
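
    # Streaming sketch (illustrative): passing a stream handler makes invoke() request a streamed
    # response and forward tokens to the handler as they arrive, instead of waiting for the full answer.
    #   layer.invoke(prompt="Tell me a joke", stream_handler=DefaultTokenStreamingHandler())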

    def _extract_token(self, event_data: Dict[str, Any]):
        # Streamed chat completion events carry incremental output in choices[0].delta;
        # the delta contains a "content" key only for chunks that add new text.
        delta = event_data["choices"][0]["delta"]
        if "content" in delta:
            return delta["content"]
        return None

    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        """Make sure the length of the prompt and answer is within the max tokens limit of the model.
        If needed, truncate the prompt text so that it fits within the limit.

        :param prompt: Prompt text to be sent to the generative model.
        """
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
            messages = prompt
        else:
            # guard against unsupported prompt types (mirrors the validation in invoke())
            raise ValueError(
                f"The prompt format is different than what the model expects. "
                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format."
            )
        n_prompt_tokens = count_openai_tokens_messages(messages, self._tokenizer)
        n_answer_tokens = self.max_length
        if (n_prompt_tokens + n_answer_tokens) <= self.max_tokens_limit:
            return prompt

        # TODO: support truncation as in _ensure_token_limit methods for other invocation layers
        raise ValueError(
            f"The prompt or the messages are too long ({n_prompt_tokens} tokens). "
            f"The length of the prompt or messages and the answer ({n_answer_tokens} tokens) should be within the "
            f"max token limit ({self.max_tokens_limit} tokens). "
            f"Reduce the length of the prompt or messages."
        )
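
    # Worked example (illustrative numbers): with max_length=500 and a model whose max_tokens_limit is
    # 4096 (e.g. gpt-3.5-turbo), a prompt counted at up to 3596 tokens is returned unchanged, while a
    # 3700-token prompt raises the ValueError above because 3700 + 500 > 4096.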

    @property
    def url(self) -> str:
        return f"{self.api_base}/chat/completions"

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]
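
For illustration, a minimal end-to-end sketch of calling the layer directly (the key and prompt are placeholders, and the call issues a real request to the configured endpoint):

if ChatGPTInvocationLayer.supports("gpt-3.5-turbo"):
    layer = ChatGPTInvocationLayer(api_key="YOUR_OPENAI_API_KEY")
    responses = layer.invoke(prompt="Summarize what an invocation layer does in one sentence.", max_tokens=50)
    print(responses[0])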